[RefBackend] Fix leaks related to ABI boundaries.

As best I can tell (e.g. from LeakSanitizer), this fixes all the leaks
except those due to buffers created internally to the codegenned code
itself (up next, I'll add the buffer deallocation pass to fix those).

The main change is that instead of attempting to pass `refbackrt::Tensor`
to the codegenned function directly, we make all the ABI types be
UnrankedMemRef, which gets passed awkwardly (but workably) as a
`{size_t rank, void *ptrToDescriptor}` pair on the ABI. The reason
`refbackrt::Tensor` wasn't workable is that MLIR doesn't really have a
way to deal with the lifetime of unranked memref descriptors created
inside a function, which is inevitably what would happen with the old
code that emitted runtime calls to
`refbackrt.to_memref`/`refbackrt.from_memref` to convert back and forth
to `refbackrt::Tensor` inside the codegenned code.
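
Concretely, the runtime-side view of that ABI is the following pair of
structs (names taken from the `MemrefDescriptor`/`UnrankedMemref`
definitions this patch adds to `Runtime.cpp`; the layout has to be kept
in sync with `mlir/ExecutionEngine/CRunnerUtils.h`):

```c++
#include <cstdint>

// What an UnrankedMemRef looks like from the runtime's side of the ABI.
struct MemrefDescriptor {
  void *allocatedPtr; // backing buffer; what a deallocation would free
  void *dataPtr;      // start of the tensor data
  std::int64_t offset;
  // Tail-allocated: int64_t sizes[rank] followed by int64_t strides[rank].
};

struct UnrankedMemref {
  std::int64_t rank;
  MemrefDescriptor *descriptor;
};
```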

So, instead of `refbackrt.to_memref`/`refbackrt.from_memref` with no
real sound basis for lifetime management, we now have a lovely piece of
code in `refbackrt::invoke` in `Runtime.cpp` that just barely seems to
be sound. We rely on the codegenned code having these properties, which
it seems to have:

- it won't free memref descriptors or their backing buffer for arguments
  of UnrankedMemRef type.

- it will allocate a separate memref descriptor for each result
  UnrankedMemRef (which is ensured by having a separate memref_cast for
  each)

- we can sniff the `allocatedPtr`s (i.e. the backing buffer pointers)
  to avoid double-freeing in the case of aliasing of the backing buffer
  (including backing buffers for arguments feeding into results)

- to catch the case of statically allocated data (which we need to avoid
  passing to `free`), we check whether the `allocatedPtr` is (no joke)
  equal to `0xDEADBEEF`, because there is otherwise no way to distinguish
  statically allocated from malloc'ed data... (the std.global_memref
  lowering to LLVM by happenstance sets the allocatedPtr equal to
  `0xDEADBEEF`, presumably mainly as a debugging aid). See the sketch
  after this list.
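
Condensed into one place (the real code spreads these checks across the
output and input cleanup loops in `refbackrt::invoke`; the helper name
here is hypothetical), the per-buffer deallocation decision looks like:

```c++
#include <cstdint>
#include <cstdlib>

// Free buffer `i` in `memrefs` at most once, assuming buffers 0..i-1
// were already handled. Uses the UnrankedMemref struct shown earlier.
static void freeBufferOnce(UnrankedMemref *memrefs, int i) {
  void *allocatedPtr = memrefs[i].descriptor->allocatedPtr;
  bool bufferNeedsFreeing = true;
  for (int j = 0; j < i; j++)
    if (allocatedPtr == memrefs[j].descriptor->allocatedPtr)
      bufferNeedsFreeing = false; // aliases an already-freed buffer
  // Statically allocated data must not be passed to free; the
  // std.global_memref lowering marks it with the magic allocatedPtr.
  if (reinterpret_cast<std::intptr_t>(allocatedPtr) == 0xDEADBEEF)
    bufferNeedsFreeing = false;
  if (bufferNeedsFreeing)
    std::free(allocatedPtr);
}
```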

Even with all this, we *still* need to make copies of all
inputs/outputs internally to `refbackrt::invoke`! And the details of how
the LLVM-level ABI gets laid out for e.g. function arguments/returns are
still super tricky.
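
As one concrete example of the trickiness: each UnrankedMemRef argument
has to be splatted into two `void*` slots of the packed argument array
consumed by the generated wrapper (this fragment is lifted from the new
`refbackrt::invoke` below):

```c++
// Each UnrankedMemRef argument occupies two slots on the ABI:
// a pointer to the rank, then a pointer to the descriptor pointer,
// matching the {size_t rank, void *ptrToDescriptor} convention.
for (int i = 0, e = inputs.size(); i < e; i++) {
  packedInputs[2 * i] = ToVoidPtr(&inputUnrankedMemrefs[i].rank);
  packedInputs[2 * i + 1] = ToVoidPtr(&inputUnrankedMemrefs[i].descriptor);
}
```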

This really highlights how deficient memref is as the general runtime
type for our use case. It's stewing in my mind how best to improve the
situation. My general gut feeling is that IREE's abstractions for this
are "right", but I need to think more how to distill those aspects of
IREE's design in a "reference" way for RefBackend.

Some implementation notes:

- This did catch a bug in our ABI wrapper functions in LowerToLLVM.cpp,
  which I had to fix (it happened to work before through some combination
  of `refbackrt::Tensor` being passed as a single pointer + probably me
  infinite-monkey-ing it until it worked).

- This actually removes 2 of the 3 compiler runtime functions; the only
  one left is `abort_if` (its declaration is shown after these notes).
  (Most of the memref descriptor code moved from CompilerRuntime.cpp to
  Runtime.cpp.)

  - This also means deleting the `refbackrt.from_memref` and
    `refbackrt.to_memref` ops.
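
For reference, the lone surviving compiler runtime function is:

```c++
// Defined in CompilerRuntime.cpp; aborts execution with a message
// when the predicate is true.
extern "C" void __npcomp_compiler_rt_abort_if(bool b, const char *msg);
```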

Sean Silva, 2020-11-24 17:18:57 -08:00
commit 46aa6d0a24 (parent 699bf5df45), pull/136/head
12 changed files with 251 additions and 411 deletions

@@ -23,14 +23,4 @@ lowered to the llvm dialect.
}];
}
def Refbackrt_Tensor
: DialectType<
Refbackrt_Dialect,
CPred<"$_self.isa<::mlir::NPCOMP::refbackrt::TensorType>()">,
"refbackrt.tensor">,
BuildableType<
"$_builder.getType<::mlir::NPCOMP::refbackrt::TensorType>()"> {
let typeDescription = [{The runtime type that represents a buffer.}];
}
#endif // #ifndef REFBACKRT_BASE


@@ -16,26 +16,6 @@ class Refbackrt_Op<string mnemonic, list<OpTrait> traits = []>
: Op<Refbackrt_Dialect, mnemonic, traits> {
}
def Refbackrt_ToMemrefOp : Refbackrt_Op<"to_memref"> {
let summary = "Gets a memref descriptor from a tensor";
let description = [{
Gets a memref descriptor from a tensor.
}];
let arguments = (ins Refbackrt_Tensor:$tensor);
let results = (outs AnyUnrankedMemRef:$memref);
let assemblyFormat = "$tensor attr-dict `:` type($memref)";
}
def Refbackrt_FromMemrefOp : Refbackrt_Op<"from_memref"> {
let summary = "Converts a memref descriptor to a tensor";
let description = [{
Copies the data from a memref into a new tensor.
}];
let arguments = (ins AnyUnrankedMemRef:$memref);
let results = (outs Refbackrt_Tensor:$tensor);
let assemblyFormat = "$memref attr-dict `:` type($memref)";
}
def Refbackrt_AbortIfOp : Refbackrt_Op<"abort_if"> {
let summary = "Aborts if the predicate is true";
let description = [{


@@ -70,6 +70,8 @@ public:
return ret;
}
int debugGetRefCount() { return ptr->refCount; }
private:
static void incref(T *ptr) {
if (!ptr)


@@ -21,23 +21,3 @@ void RefbackrtDialect::initialize() {
>();
addTypes<TensorType>();
}
Type RefbackrtDialect::parseType(DialectAsmParser &parser) const {
StringRef keyword;
if (parser.parseKeyword(&keyword))
return Type();
if (keyword == "tensor")
return TensorType::get(getContext());
parser.emitError(parser.getNameLoc(), "unknown type in 'refbackrt' dialect: ")
<< keyword;
return Type();
}
void RefbackrtDialect::printType(Type type, DialectAsmPrinter &os) const {
TypeSwitch<Type>(type)
.Case<NPCOMP::refbackrt::TensorType>([&](Type) { os << "tensor"; })
.Default(
[&](Type) { llvm_unreachable("unexpected 'refbackrt' type kind"); });
}


@@ -77,35 +77,6 @@ public:
};
} // namespace
namespace {
// FromMemrefOp requires special handling so that the unranked memref descriptor
// gets passed as two separate arguments instead of as a struct.
class FromMemrefOpCompilerRuntimeLowering
: public OpConversionPattern<refbackrt::FromMemrefOp> {
public:
FromMemrefOpCompilerRuntimeLowering(LLVM::LLVMFuncOp backingFunc)
: OpConversionPattern<refbackrt::FromMemrefOp>(backingFunc.getContext()),
backingFunc(backingFunc) {}
LogicalResult
matchAndRewrite(refbackrt::FromMemrefOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
auto structVal = operands[0];
Value rank = rewriter.create<LLVM::ExtractValueOp>(
op.getLoc(),
structVal.getType().cast<LLVMType>().getStructElementType(0), structVal,
rewriter.getI32ArrayAttr({0}));
Value descriptorPtr = rewriter.create<LLVM::ExtractValueOp>(
op.getLoc(),
structVal.getType().cast<LLVMType>().getStructElementType(1), structVal,
rewriter.getI32ArrayAttr({1}));
rewriter.replaceOpWithNewOp<LLVM::CallOp>(
op, backingFunc, ValueRange({rank, descriptorPtr}));
return success();
}
LLVM::LLVMFuncOp backingFunc;
};
} // namespace
static LLVM::GlobalOp createGlobalString(ModuleOp module, StringAttr msg,
OpBuilder &builder, Location loc) {
// TODO: Deduplicate strings.
@@ -188,35 +159,6 @@ static void populateCompilerRuntimePatterns(ModuleOp module,
"abort_if", abortIfFuncTy, builder, module.getLoc());
patterns.insert<AbortIfOpCompilerRuntimeLowering>(abortIfFunc);
}
auto convertFunctionType = [&](FunctionType type) {
TypeConverter::SignatureConversion conversion(type.getNumInputs());
return typeConverter.convertFunctionSignature(type, /*isVariadic=*/false,
conversion);
};
{
auto mlirFunctionType = builder.getFunctionType(
{builder.getType<refbackrt::TensorType>()},
{UnrankedMemRefType::get(builder.getF32Type(), /*memorySpace=*/0)});
LLVMType funcTy = convertFunctionType(mlirFunctionType);
LLVMFuncOp toMemrefFunc = createCompilerRuntimeFuncDecl(
"to_memref", funcTy, builder, module.getLoc());
patterns.insert<TrivialCompilerRuntimeLowering<refbackrt::ToMemrefOp>>(
toMemrefFunc);
}
{
// TODO: Pass in an element type enum, since the unranked memref descriptor
// doesn't know its own dtype.
auto mlirFunctionType = builder.getFunctionType(
{UnrankedMemRefType::get(builder.getF32Type(), /*memorySpace=*/0)},
{builder.getType<refbackrt::TensorType>()});
LLVMType funcTy = convertFunctionType(mlirFunctionType);
LLVMFuncOp fromMemrefFunc = createCompilerRuntimeFuncDecl(
"from_memref", funcTy, builder, module.getLoc());
patterns.insert<FromMemrefOpCompilerRuntimeLowering>(fromMemrefFunc);
}
}
//===----------------------------------------------------------------------===//
@@ -390,9 +332,12 @@ static Value getTypedAddressFromVoidStarStar(Value voidStarStar, int32_t index,
Value ci = builder.create<LLVM::ConstantOp>(
loc, LLVMType::getIntNTy(builder.getContext(), 32),
builder.getI32IntegerAttr(index));
auto inputPtr = builder.create<LLVM::GEPOp>(
loc, LLVMType::getInt8PtrTy(builder.getContext()), voidStarStar,
ValueRange(ci));
// Do `voidStarStar[i]` as a gep + load.
auto inputPtrAddr = builder.create<LLVM::GEPOp>(
loc, LLVMType::getInt8PtrTy(builder.getContext()).getPointerTo(),
voidStarStar, ValueRange(ci));
auto inputPtr = builder.create<LLVM::LoadOp>(loc, inputPtrAddr);
return builder.create<LLVM::BitcastOp>(loc, ty.getPointerTo(), inputPtr);
}
@@ -409,6 +354,21 @@ static SmallVector<Value, 6> loadCallArgs(Value inputsPtrPtr, LLVMType funcTy,
return callArgs;
}
static LLVM::LLVMType getUnrankedMemrefDescriptorType(MLIRContext *context) {
LLVMTypeConverter converter(context);
// LLVMTypeConverter doesn't directly expose the struct type used to represent
// unranked memrefs on ABI boundaries. To get that type, we convert
// an unranked memref type and see what it produces.
//
// An unranked memref is just a size_t for the rank and a void* pointer to
// descriptor, so the choice of element type here is arbitrary -- it all
// converts to the same thing.
return converter
.convertType(UnrankedMemRefType::get(Float32Type::get(context),
/*memorySpace=*/0))
.cast<LLVM::LLVMType>();
}
// Writes out the logical results of the wrapper function through the void**
// passed on the ABI boundary. Because LLVM (and hence llvm.func)
// only supports a single return type (or void/no results), the logic here needs
@@ -424,13 +384,15 @@ static void storeWrapperResults(LLVM::CallOp callToWrapped, Value resultsPtrPtr,
Value result = callToWrapped.getResult(0);
auto ty = result.getType().cast<LLVMType>();
// 1 logical result.
if (!ty.isStructTy()) {
if (ty == getUnrankedMemrefDescriptorType(ty.getContext())) {
Value addr =
getTypedAddressFromVoidStarStar(resultsPtrPtr, 0, ty, builder, loc);
builder.create<LLVM::StoreOp>(loc, result, addr);
return;
}
// >=2 logical results.
assert(ty.isStructTy() && "must be a multi-result packed struct!");
// >=2 logical results. The convention linked above will create a struct
// wrapping.
for (int i = 0, e = ty.getStructNumElements(); i < e; i++) {
auto elementTy = ty.getStructElementType(i);
Value addr = getTypedAddressFromVoidStarStar(resultsPtrPtr, i, elementTy,
@@ -492,11 +454,6 @@ class LowerToLLVM : public LowerToLLVMBase<LowerToLLVM> {
LLVMTypeConverter converter(context);
// refbackrt::TensorType is passed as a `void*` in the ABI.
converter.addConversion([&](refbackrt::TensorType type) {
return LLVMType::getInt8PtrTy(context);
});
OwningRewritePatternList patterns;
LLVMConversionTarget target(*context);
target.addDynamicallyLegalOp<FuncOp>(


@@ -97,7 +97,8 @@ public:
} // namespace
namespace {
// At ABI bondaries, use !refbackrt.tensor instead of memref.
// At ABI boundaries, convert all memrefs to unranked memrefs so that they have
// a fixed ABI.
class FuncOpSignatureConversion : public OpConversionPattern<FuncOp> {
public:
using OpConversionPattern::OpConversionPattern;
@@ -128,9 +129,7 @@ public:
for (auto newAndOldArg :
llvm::zip(newEntry.getArguments(), oldEntry.getArguments())) {
std::tie(newArg, oldArg) = newAndOldArg;
auto abiMemref = rewriter.create<refbackrt::ToMemrefOp>(
op.getLoc(), getABIMemrefType(oldArg.getType()), newArg);
auto memref = rewriter.create<MemRefCastOp>(op.getLoc(), abiMemref,
auto memref = rewriter.create<MemRefCastOp>(op.getLoc(), newArg,
oldArg.getType());
rewriter.replaceUsesOfBlockArgument(oldArg, memref);
}
@@ -141,7 +140,7 @@ public:
} // namespace
namespace {
// At the return ABI boundaries, convert to !refbackrt.tensor type.
// At the return ABI boundaries, convert to the ABI type.
// This pattern is needed to trigger the type conversion mechanics to do a
// target materialization.
class RewriteReturnOp : public OpConversionPattern<ReturnOp> {
@@ -161,16 +160,14 @@ static LogicalResult doDialectConversion(ModuleOp module) {
TypeConverter typeConverter;
typeConverter.addConversion([](Type type) { return type; });
typeConverter.addConversion([](MemRefType type) {
return refbackrt::TensorType::get(type.getContext());
});
typeConverter.addConversion(
[](MemRefType type) { return getABIMemrefType(type); });
typeConverter.addTargetMaterialization(
[](OpBuilder &builder, refbackrt::TensorType type, ValueRange inputs,
[](OpBuilder &builder, UnrankedMemRefType type, ValueRange inputs,
Location loc) -> Value {
assert(inputs.size() == 1);
auto abiMemref = builder.create<MemRefCastOp>(
return builder.create<MemRefCastOp>(
loc, inputs[0], getABIMemrefType(inputs[0].getType()));
return builder.create<refbackrt::FromMemrefOp>(loc, type, abiMemref);
});
OwningRewritePatternList patterns;


@@ -26,91 +26,3 @@ extern "C" void __npcomp_compiler_rt_abort_if(bool b, const char *msg) {
std::exit(1);
}
}
namespace {
// These definitions are based on the ones in
// `mlir/ExecutionEngine/CRunnerUtils.h` and the layouts need to be kept in
// sync.
//
// Those definitions are flawed though because they are overly templated.
struct MemrefDescriptor {
void *allocatedPtr;
void *dataPtr;
std::int64_t offset;
// Tail-allocated int64_t sizes followed by strides.
MutableArrayRef<std::int64_t> getSizes(int assumedRank) {
auto *tail = reinterpret_cast<std::int64_t *>(this + 1);
return MutableArrayRef<std::int64_t>(tail, assumedRank);
}
MutableArrayRef<std::int64_t> getStrides(int assumedRank) {
auto *tail = reinterpret_cast<std::int64_t *>(this + 1);
return MutableArrayRef<std::int64_t>(tail + assumedRank, assumedRank);
}
// Returns a malloc-allocated MemrefDescriptor with the specified extents and
// default striding.
static MemrefDescriptor *create(ArrayRef<std::int32_t> extents, void *data);
};
struct UnrankedMemref {
int64_t rank;
MemrefDescriptor *descriptor;
};
MemrefDescriptor *MemrefDescriptor::create(ArrayRef<std::int32_t> extents,
void *data) {
auto rank = extents.size();
auto allocSize = sizeof(MemrefDescriptor) + sizeof(std::int64_t) * 2 * rank;
auto *descriptor = static_cast<MemrefDescriptor *>(std::malloc(allocSize));
descriptor->allocatedPtr = data;
descriptor->dataPtr = data;
descriptor->offset = 0;
// Iterate in reverse, copying the dimension sizes (i.e. extents) and
// calculating the strides for a standard dense layout.
std::int64_t stride = 1;
for (int i = 0, e = rank; i < e; i++) {
auto revIdx = e - i - 1;
descriptor->getSizes(rank)[revIdx] = extents[revIdx];
descriptor->getStrides(rank)[revIdx] = stride;
stride *= extents[revIdx];
}
return descriptor;
}
std::int32_t getNumElements(MemrefDescriptor *descriptor, int assumedRank) {
if (assumedRank == 0)
return 1;
return descriptor->getSizes(assumedRank)[0] *
descriptor->getStrides(assumedRank)[0];
}
} // namespace
extern "C" UnrankedMemref __npcomp_compiler_rt_to_memref(Tensor *tensor) {
auto byteSize = tensor->getDataByteSize();
void *data = std::malloc(byteSize);
std::memcpy(data, tensor->getData(), byteSize);
auto *descriptor = MemrefDescriptor::create(tensor->getExtents(), data);
return UnrankedMemref{tensor->getRank(), descriptor};
}
extern "C" Tensor *
__npcomp_compiler_rt_from_memref(std::int64_t rank,
MemrefDescriptor *descriptor) {
auto numElements = getNumElements(descriptor, rank);
// TODO: Have the compiler pass this as an argument.
auto elementType = ElementType::F32;
auto byteSize = getElementTypeByteSize(elementType) * numElements;
void *data = std::malloc(byteSize);
std::memcpy(data, descriptor->dataPtr, byteSize);
auto extents64 = descriptor->getSizes(rank);
// Launder from std::int64_t to std::int32_t.
constexpr int kMaxRank = 20;
std::array<std::int32_t, kMaxRank> extents32Buf;
for (int i = 0, e = extents64.size(); i < e; i++)
extents32Buf[i] = extents64[i];
return Tensor::createRaw(ArrayRef<std::int32_t>(extents32Buf.data(), rank),
elementType, data);
}


@@ -17,6 +17,91 @@
using namespace refbackrt;
//===----------------------------------------------------------------------===//
// Memref descriptors for interacting with MLIR codegenerated code.
//===----------------------------------------------------------------------===//
namespace {
// These definitions are based on the ones in
// `mlir/ExecutionEngine/CRunnerUtils.h` and the layouts need to be kept in
// sync.
//
// Those definitions are flawed though because they are overly templated.
struct MemrefDescriptor {
void *allocatedPtr;
void *dataPtr;
std::int64_t offset;
// Tail-allocated int64_t sizes followed by strides.
MutableArrayRef<std::int64_t> getSizes(int assumedRank) {
auto *tail = reinterpret_cast<std::int64_t *>(this + 1);
return MutableArrayRef<std::int64_t>(tail, assumedRank);
}
MutableArrayRef<std::int64_t> getStrides(int assumedRank) {
auto *tail = reinterpret_cast<std::int64_t *>(this + 1);
return MutableArrayRef<std::int64_t>(tail + assumedRank, assumedRank);
}
// Returns a malloc-allocated MemrefDescriptor with the specified extents and
// default striding.
static MemrefDescriptor *create(ArrayRef<std::int32_t> extents, void *data);
// Returns the number of elements in this MemrefDescriptor, assuming this
// descriptor has rank `assumedRank`.
std::int32_t getNumElements(int assumedRank) {
if (assumedRank == 0)
return 1;
return getSizes(assumedRank)[0] * getStrides(assumedRank)[0];
}
};
} // namespace
namespace {
struct UnrankedMemref {
int64_t rank;
MemrefDescriptor *descriptor;
};
} // namespace
MemrefDescriptor *MemrefDescriptor::create(ArrayRef<std::int32_t> extents,
void *data) {
auto rank = extents.size();
auto allocSize = sizeof(MemrefDescriptor) + sizeof(std::int64_t) * 2 * rank;
auto *descriptor = static_cast<MemrefDescriptor *>(std::malloc(allocSize));
descriptor->allocatedPtr = data;
descriptor->dataPtr = data;
descriptor->offset = 0;
// Iterate in reverse, copying the dimension sizes (i.e. extents) and
// calculating the strides for a standard dense layout.
std::int64_t stride = 1;
for (int i = 0, e = rank; i < e; i++) {
auto revIdx = e - i - 1;
descriptor->getSizes(rank)[revIdx] = extents[revIdx];
descriptor->getStrides(rank)[revIdx] = stride;
stride *= extents[revIdx];
}
return descriptor;
}
static UnrankedMemref convertRefbackrtTensorToUnrankedMemref(Tensor *tensor) {
auto byteSize = tensor->getDataByteSize();
void *data = std::malloc(byteSize);
std::memcpy(data, tensor->getData(), byteSize);
auto *descriptor = MemrefDescriptor::create(tensor->getExtents(), data);
return UnrankedMemref{tensor->getRank(), descriptor};
}
static Tensor *convertUnrankedMemrefToRefbackrtTensor(
std::int64_t rank, MemrefDescriptor *descriptor, ElementType elementType) {
// Launder from std::int64_t to std::int32_t.
auto extents64 = descriptor->getSizes(rank);
constexpr int kMaxRank = 20;
std::array<std::int32_t, kMaxRank> extents32Buf;
for (int i = 0, e = extents64.size(); i < e; i++)
extents32Buf[i] = extents64[i];
return Tensor::createRaw(ArrayRef<std::int32_t>(extents32Buf.data(), rank),
elementType, descriptor->dataPtr);
}
//===----------------------------------------------------------------------===//
// Tensor
//===----------------------------------------------------------------------===//
@@ -93,25 +178,71 @@ void refbackrt::invoke(ModuleDescriptor *moduleDescriptor,
assert(descriptor && "unknown function name");
assert(inputs.size() < kMaxArity && "number of inputs exceeds kMaxArity");
assert(outputs.size() < kMaxArity && "number of outputs exceeds kMaxArity");
std::array<Tensor *, kMaxArity> inputTensorPtrs;
std::array<Tensor *, kMaxArity> outputTensorPtrs;
std::array<void *, kMaxArity> packedInputs;
std::array<UnrankedMemref, kMaxArity> inputUnrankedMemrefs;
std::array<UnrankedMemref, kMaxArity> outputUnrankedMemrefs;
std::array<void *, kMaxArity * 2> packedInputs;
std::array<void *, kMaxArity> packedOutputs;
for (int i = 0, e = inputs.size(); i < e; i++)
inputTensorPtrs[i] = inputs[i].get();
for (int i = 0, e = inputs.size(); i < e; i++)
packedInputs[i] = ToVoidPtr(inputTensorPtrs[i]);
for (int i = 0, e = inputs.size(); i < e; i++) {
inputUnrankedMemrefs[i] =
convertRefbackrtTensorToUnrankedMemref(inputs[i].get());
}
for (int i = 0, e = inputs.size(); i < e; i++) {
packedInputs[2 * i] = ToVoidPtr(&inputUnrankedMemrefs[i].rank);
packedInputs[2 * i + 1] = ToVoidPtr(&inputUnrankedMemrefs[i].descriptor);
}
for (int i = 0, e = outputs.size(); i < e; i++) {
packedOutputs[i] = ToVoidPtr(&outputUnrankedMemrefs[i]);
}
descriptor->functionPtr(packedInputs.data(), packedOutputs.data());
for (int i = 0, e = outputs.size(); i < e; i++)
outputTensorPtrs[i] = static_cast<Tensor *>(packedOutputs[i]);
// TODO: Actually manage refcounts inside the compiler.
// Right now, we only pass around refbackrt.tensor's in trivial ways on ABI
// boundaries, so the following contract of the compiler-generated code works:
// - input tensors are never retained or released
// - output tensors always have refcount 0. Hence the next line here is
// actually essential because it increments the refcounts so they are nonzero.
for (int i = 0, e = outputs.size(); i < e; i++)
outputs[i] = Ref<Tensor>(outputTensorPtrs[i]);
for (int i = 0, e = outputs.size(); i < e; i++) {
// TODO: Have compiler emit the element type in the metadata.
auto elementType = ElementType::F32;
Tensor *tensor = convertUnrankedMemrefToRefbackrtTensor(
outputUnrankedMemrefs[i].rank, outputUnrankedMemrefs[i].descriptor,
elementType);
outputs[i] = Ref<Tensor>(tensor);
}
for (int i = 0, e = outputs.size(); i < e; i++) {
void *allocatedPtr = outputUnrankedMemrefs[i].descriptor->allocatedPtr;
// Multiple returned memrefs can point into the same underlying
// malloc allocation. Do a linear scan to see if any of the previously
// deallocated buffers already freed this pointer.
bool bufferNeedsFreeing = true;
for (int j = 0; j < i; j++) {
if (allocatedPtr == outputUnrankedMemrefs[j].descriptor->allocatedPtr)
bufferNeedsFreeing = false;
}
if (bufferNeedsFreeing)
std::free(allocatedPtr);
}
for (int i = 0, e = inputs.size(); i < e; i++) {
void *allocatedPtr = inputUnrankedMemrefs[i].descriptor->allocatedPtr;
bool bufferNeedsFreeing = true;
for (int j = 0, je = outputs.size(); j < je; j++) {
if (allocatedPtr == outputUnrankedMemrefs[j].descriptor->allocatedPtr)
bufferNeedsFreeing = false;
}
// HACK: The returned memref can point into statically allocated memory that
// we can't pass to `free`, such as the result of lowering a tensor-valued
// `std.constant` to `std.global_memref`. The LLVM lowering of
// std.global_memref sets the allocated pointer to the magic value
// 0xDEADBEEF, which we sniff for here. This is yet another strong signal
// that memref is really not the right abstraction for ABI's.
if (reinterpret_cast<std::intptr_t>(allocatedPtr) == 0xDEADBEEF)
bufferNeedsFreeing = false;
if (bufferNeedsFreeing)
std::free(allocatedPtr);
}
for (int i = 0, e = outputs.size(); i < e; i++) {
// The LLVM lowering guarantees that each returned unranked memref
// descriptor is separately malloc'ed, so no need to do anything special
// like we had to do for the allocatedPtr's.
std::free(outputUnrankedMemrefs[i].descriptor);
}
for (int i = 0, e = inputs.size(); i < e; i++) {
std::free(inputUnrankedMemrefs[i].descriptor);
}
}
LogicalResult refbackrt::getMetadata(ModuleDescriptor *moduleDescriptor,


@@ -6,8 +6,6 @@ refbackrt.module_metadata {
refbackrt.func_metadata {funcName = @f, numInputs = 1 : i32, numOutputs = 0 : i32}
}
// CHECK-LABEL: func @f
// CHECK-SAME: !refbackrt.tensor
func @f(%arg0: !refbackrt.tensor) {
func @f(%arg0: memref<*xf32>) {
return
}


@@ -1,118 +1,75 @@
// RUN: npcomp-opt -refback-lower-to-llvm -split-input-file <%s | FileCheck %s --dump-input=fail
// CHECK-LABEL: llvm.func @__refbackrt_wrapper_identity(
// CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr<ptr<i8>>,
// CHECK-SAME: %[[VAL_1:.*]]: !llvm.ptr<ptr<i8>>) {
// CHECK: %[[VAL_2:.*]] = llvm.mlir.constant(0 : i32) : !llvm.i32
// CHECK: %[[VAL_3:.*]] = llvm.getelementptr %[[VAL_0]]{{\[}}%[[VAL_2]]] : (!llvm.ptr<ptr<i8>>, !llvm.i32) -> !llvm.ptr<i8>
// CHECK: %[[VAL_4:.*]] = llvm.bitcast %[[VAL_3]] : !llvm.ptr<i8> to !llvm.ptr<ptr<i8>>
// CHECK: %[[VAL_5:.*]] = llvm.load %[[VAL_4]] : !llvm.ptr<ptr<i8>>
// CHECK: %[[VAL_6:.*]] = llvm.call @identity(%[[VAL_5]]) : (!llvm.ptr<i8>) -> !llvm.ptr<i8>
// CHECK: %[[VAL_7:.*]] = llvm.mlir.constant(0 : i32) : !llvm.i32
// CHECK: %[[VAL_8:.*]] = llvm.getelementptr %[[VAL_1]]{{\[}}%[[VAL_7]]] : (!llvm.ptr<ptr<i8>>, !llvm.i32) -> !llvm.ptr<i8>
// CHECK: %[[VAL_9:.*]] = llvm.bitcast %[[VAL_8]] : !llvm.ptr<i8> to !llvm.ptr<ptr<i8>>
// CHECK: llvm.store %[[VAL_6]], %[[VAL_9]] : !llvm.ptr<ptr<i8>>
// CHECK: llvm.return
// CHECK: }
// CHECK: llvm.func @__npcomp_compiler_rt_abort_if(!llvm.i1, !llvm.ptr<i8>)
// CHECK: llvm.func @__npcomp_compiler_rt_to_memref(!llvm.ptr<i8>) -> !llvm.struct<(i64, ptr<i8>)>
// CHECK: llvm.func @__npcomp_compiler_rt_from_memref(!llvm.i64, !llvm.ptr<i8>) -> !llvm.ptr<i8>
// CHECK: llvm.mlir.global internal constant @__npcomp_internal_constant_identity("identity")
// CHECK-LABEL: llvm.mlir.global internal constant @__npcomp_func_descriptors() : !llvm.array<1 x struct<(i32, ptr<i8>, ptr<i8>, i32, i32)>> {
// CHECK: %[[VAL_0:.*]] = llvm.mlir.undef : !llvm.array<1 x struct<(i32, ptr<i8>, ptr<i8>, i32, i32)>>
// CHECK: %[[VAL_1:.*]] = llvm.mlir.constant(0 : i32) : !llvm.i32
// CHECK: %[[VAL_2:.*]] = llvm.mlir.constant(8 : i32) : !llvm.i32
// CHECK: %[[VAL_3:.*]] = llvm.insertvalue %[[VAL_2]], %[[VAL_0]][0 : i32, 0 : i32] : !llvm.array<1 x struct<(i32, ptr<i8>, ptr<i8>, i32, i32)>>
// CHECK: %[[VAL_4:.*]] = llvm.mlir.addressof @__npcomp_internal_constant_identity : !llvm.ptr<array<8 x i8>>
// CHECK: %[[VAL_5:.*]] = llvm.getelementptr %[[VAL_4]]{{\[}}%[[VAL_1]], %[[VAL_1]]] : (!llvm.ptr<array<8 x i8>>, !llvm.i32, !llvm.i32) -> !llvm.ptr<i8>
// CHECK: %[[VAL_6:.*]] = llvm.insertvalue %[[VAL_5]], %[[VAL_3]][0 : i32, 1 : i32] : !llvm.array<1 x struct<(i32, ptr<i8>, ptr<i8>, i32, i32)>>
// CHECK: %[[VAL_7:.*]] = llvm.mlir.addressof @__refbackrt_wrapper_identity : !llvm.ptr<func<void (ptr<ptr<i8>>, ptr<ptr<i8>>)>>
// CHECK: %[[VAL_8:.*]] = llvm.bitcast %[[VAL_7]] : !llvm.ptr<func<void (ptr<ptr<i8>>, ptr<ptr<i8>>)>> to !llvm.ptr<i8>
// CHECK: %[[VAL_9:.*]] = llvm.insertvalue %[[VAL_8]], %[[VAL_6]][0 : i32, 2 : i32] : !llvm.array<1 x struct<(i32, ptr<i8>, ptr<i8>, i32, i32)>>
// CHECK: %[[VAL_10:.*]] = llvm.mlir.constant(1 : i32) : !llvm.i32
// CHECK: %[[VAL_11:.*]] = llvm.insertvalue %[[VAL_10]], %[[VAL_9]][0 : i32, 3 : i32] : !llvm.array<1 x struct<(i32, ptr<i8>, ptr<i8>, i32, i32)>>
// CHECK: %[[VAL_12:.*]] = llvm.mlir.constant(1 : i32) : !llvm.i32
// CHECK: %[[VAL_13:.*]] = llvm.insertvalue %[[VAL_12]], %[[VAL_11]][0 : i32, 4 : i32] : !llvm.array<1 x struct<(i32, ptr<i8>, ptr<i8>, i32, i32)>>
// CHECK: llvm.return %[[VAL_13]] : !llvm.array<1 x struct<(i32, ptr<i8>, ptr<i8>, i32, i32)>>
// CHECK: }
// CHECK-LABEL: llvm.mlir.global external constant @_mlir___npcomp_module_descriptor() : !llvm.struct<(i32, ptr<struct<(i32, ptr<i8>, ptr<i8>, i32, i32)>>)> {
// CHECK: %[[VAL_0:.*]] = llvm.mlir.undef : !llvm.struct<(i32, ptr<struct<(i32, ptr<i8>, ptr<i8>, i32, i32)>>)>
// CHECK: %[[VAL_1:.*]] = llvm.mlir.constant(1 : i32) : !llvm.i32
// CHECK: %[[VAL_2:.*]] = llvm.insertvalue %[[VAL_1]], %[[VAL_0]][0 : i32] : !llvm.struct<(i32, ptr<struct<(i32, ptr<i8>, ptr<i8>, i32, i32)>>)>
// CHECK: %[[VAL_3:.*]] = llvm.mlir.addressof @__npcomp_func_descriptors : !llvm.ptr<array<1 x struct<(i32, ptr<i8>, ptr<i8>, i32, i32)>>>
// CHECK: %[[VAL_4:.*]] = llvm.bitcast %[[VAL_3]] : !llvm.ptr<array<1 x struct<(i32, ptr<i8>, ptr<i8>, i32, i32)>>> to !llvm.ptr<struct<(i32, ptr<i8>, ptr<i8>, i32, i32)>>
// CHECK: %[[VAL_5:.*]] = llvm.insertvalue %[[VAL_4]], %[[VAL_2]][1 : i32] : !llvm.struct<(i32, ptr<struct<(i32, ptr<i8>, ptr<i8>, i32, i32)>>)>
// CHECK: llvm.return %[[VAL_5]] : !llvm.struct<(i32, ptr<struct<(i32, ptr<i8>, ptr<i8>, i32, i32)>>)>
// CHECK: }
refbackrt.module_metadata {
refbackrt.func_metadata {funcName = @identity, numInputs = 1 : i32, numOutputs = 1 : i32}
}
// CHECK-LABEL: llvm.func @identity(
// CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr<i8>) -> !llvm.ptr<i8> {
// CHECK: llvm.return %[[VAL_0]] : !llvm.ptr<i8>
// CHECK: }
func @identity(%arg0: !refbackrt.tensor) -> !refbackrt.tensor {
return %arg0 : !refbackrt.tensor
}
// -----
// Test input/output arg marshaling.
// CHECK-LABEL: llvm.func @__refbackrt_wrapper_inputs1results2(
// CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr<ptr<i8>>,
// CHECK-SAME: %[[VAL_1:.*]]: !llvm.ptr<ptr<i8>>) {
// CHECK: %[[VAL_2:.*]] = llvm.mlir.constant(0 : i32) : !llvm.i32
// CHECK: %[[VAL_3:.*]] = llvm.getelementptr %[[VAL_0]]{{\[}}%[[VAL_2]]] : (!llvm.ptr<ptr<i8>>, !llvm.i32) -> !llvm.ptr<i8>
// CHECK: %[[VAL_4:.*]] = llvm.bitcast %[[VAL_3]] : !llvm.ptr<i8> to !llvm.ptr<ptr<i8>>
// CHECK: %[[VAL_5:.*]] = llvm.load %[[VAL_4]] : !llvm.ptr<ptr<i8>>
// CHECK: %[[VAL_6:.*]] = llvm.call @inputs1results2(%[[VAL_5]]) : (!llvm.ptr<i8>) -> !llvm.struct<(ptr<i8>, ptr<i8>)>
// CHECK: %[[VAL_7:.*]] = llvm.mlir.constant(0 : i32) : !llvm.i32
// CHECK: %[[VAL_8:.*]] = llvm.getelementptr %[[VAL_1]]{{\[}}%[[VAL_7]]] : (!llvm.ptr<ptr<i8>>, !llvm.i32) -> !llvm.ptr<i8>
// CHECK: %[[VAL_9:.*]] = llvm.bitcast %[[VAL_8]] : !llvm.ptr<i8> to !llvm.ptr<ptr<i8>>
// CHECK: %[[VAL_10:.*]] = llvm.extractvalue %[[VAL_6]][0 : i32] : !llvm.struct<(ptr<i8>, ptr<i8>)>
// CHECK: llvm.store %[[VAL_10]], %[[VAL_9]] : !llvm.ptr<ptr<i8>>
// CHECK: %[[VAL_11:.*]] = llvm.mlir.constant(1 : i32) : !llvm.i32
// CHECK: %[[VAL_12:.*]] = llvm.getelementptr %[[VAL_1]]{{\[}}%[[VAL_11]]] : (!llvm.ptr<ptr<i8>>, !llvm.i32) -> !llvm.ptr<i8>
// CHECK: %[[VAL_13:.*]] = llvm.bitcast %[[VAL_12]] : !llvm.ptr<i8> to !llvm.ptr<ptr<i8>>
// CHECK: %[[VAL_14:.*]] = llvm.extractvalue %[[VAL_6]][1 : i32] : !llvm.struct<(ptr<i8>, ptr<i8>)>
// CHECK: llvm.store %[[VAL_14]], %[[VAL_13]] : !llvm.ptr<ptr<i8>>
// CHECK: %[[VAL_3:.*]] = llvm.getelementptr %[[VAL_0]]{{\[}}%[[VAL_2]]] : (!llvm.ptr<ptr<i8>>, !llvm.i32) -> !llvm.ptr<ptr<i8>>
// CHECK: %[[VAL_4:.*]] = llvm.load %[[VAL_3]] : !llvm.ptr<ptr<i8>>
// CHECK: %[[VAL_5:.*]] = llvm.bitcast %[[VAL_4]] : !llvm.ptr<i8> to !llvm.ptr<i64>
// CHECK: %[[VAL_6:.*]] = llvm.load %[[VAL_5]] : !llvm.ptr<i64>
// CHECK: %[[VAL_7:.*]] = llvm.mlir.constant(1 : i32) : !llvm.i32
// CHECK: %[[VAL_8:.*]] = llvm.getelementptr %[[VAL_0]]{{\[}}%[[VAL_7]]] : (!llvm.ptr<ptr<i8>>, !llvm.i32) -> !llvm.ptr<ptr<i8>>
// CHECK: %[[VAL_9:.*]] = llvm.load %[[VAL_8]] : !llvm.ptr<ptr<i8>>
// CHECK: %[[VAL_10:.*]] = llvm.bitcast %[[VAL_9]] : !llvm.ptr<i8> to !llvm.ptr<ptr<i8>>
// CHECK: %[[VAL_11:.*]] = llvm.load %[[VAL_10]] : !llvm.ptr<ptr<i8>>
// CHECK: %[[VAL_12:.*]] = llvm.call @inputs1results2(%[[VAL_6]], %[[VAL_11]]) : (!llvm.i64, !llvm.ptr<i8>) -> !llvm.struct<(struct<(i64, ptr<i8>)>, struct<(i64, ptr<i8>)>)>
// CHECK: %[[VAL_13:.*]] = llvm.mlir.constant(0 : i32) : !llvm.i32
// CHECK: %[[VAL_14:.*]] = llvm.getelementptr %[[VAL_1]]{{\[}}%[[VAL_13]]] : (!llvm.ptr<ptr<i8>>, !llvm.i32) -> !llvm.ptr<ptr<i8>>
// CHECK: %[[VAL_15:.*]] = llvm.load %[[VAL_14]] : !llvm.ptr<ptr<i8>>
// CHECK: %[[VAL_16:.*]] = llvm.bitcast %[[VAL_15]] : !llvm.ptr<i8> to !llvm.ptr<struct<(i64, ptr<i8>)>>
// CHECK: %[[VAL_17:.*]] = llvm.extractvalue %[[VAL_12]][0 : i32] : !llvm.struct<(struct<(i64, ptr<i8>)>, struct<(i64, ptr<i8>)>)>
// CHECK: llvm.store %[[VAL_17]], %[[VAL_16]] : !llvm.ptr<struct<(i64, ptr<i8>)>>
// CHECK: %[[VAL_18:.*]] = llvm.mlir.constant(1 : i32) : !llvm.i32
// CHECK: %[[VAL_19:.*]] = llvm.getelementptr %[[VAL_1]]{{\[}}%[[VAL_18]]] : (!llvm.ptr<ptr<i8>>, !llvm.i32) -> !llvm.ptr<ptr<i8>>
// CHECK: %[[VAL_20:.*]] = llvm.load %[[VAL_19]] : !llvm.ptr<ptr<i8>>
// CHECK: %[[VAL_21:.*]] = llvm.bitcast %[[VAL_20]] : !llvm.ptr<i8> to !llvm.ptr<struct<(i64, ptr<i8>)>>
// CHECK: %[[VAL_22:.*]] = llvm.extractvalue %[[VAL_12]][1 : i32] : !llvm.struct<(struct<(i64, ptr<i8>)>, struct<(i64, ptr<i8>)>)>
// CHECK: llvm.store %[[VAL_22]], %[[VAL_21]] : !llvm.ptr<struct<(i64, ptr<i8>)>>
// CHECK: llvm.return
// CHECK: }
// CHECK-LABEL: llvm.func @__refbackrt_wrapper_inputs1results1(
// CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr<ptr<i8>>,
// CHECK-SAME: %[[VAL_1:.*]]: !llvm.ptr<ptr<i8>>) {
// CHECK: %[[VAL_2:.*]] = llvm.mlir.constant(0 : i32) : !llvm.i32
// CHECK: %[[VAL_3:.*]] = llvm.getelementptr %[[VAL_0]]{{\[}}%[[VAL_2]]] : (!llvm.ptr<ptr<i8>>, !llvm.i32) -> !llvm.ptr<i8>
// CHECK: %[[VAL_4:.*]] = llvm.bitcast %[[VAL_3]] : !llvm.ptr<i8> to !llvm.ptr<ptr<i8>>
// CHECK: %[[VAL_5:.*]] = llvm.load %[[VAL_4]] : !llvm.ptr<ptr<i8>>
// CHECK: %[[VAL_6:.*]] = llvm.call @inputs1results1(%[[VAL_5]]) : (!llvm.ptr<i8>) -> !llvm.ptr<i8>
// CHECK: %[[VAL_7:.*]] = llvm.mlir.constant(0 : i32) : !llvm.i32
// CHECK: %[[VAL_8:.*]] = llvm.getelementptr %[[VAL_1]]{{\[}}%[[VAL_7]]] : (!llvm.ptr<ptr<i8>>, !llvm.i32) -> !llvm.ptr<i8>
// CHECK: %[[VAL_9:.*]] = llvm.bitcast %[[VAL_8]] : !llvm.ptr<i8> to !llvm.ptr<ptr<i8>>
// CHECK: llvm.store %[[VAL_6]], %[[VAL_9]] : !llvm.ptr<ptr<i8>>
// CHECK: %[[VAL_3:.*]] = llvm.getelementptr %[[VAL_0]]{{\[}}%[[VAL_2]]] : (!llvm.ptr<ptr<i8>>, !llvm.i32) -> !llvm.ptr<ptr<i8>>
// CHECK: %[[VAL_4:.*]] = llvm.load %[[VAL_3]] : !llvm.ptr<ptr<i8>>
// CHECK: %[[VAL_5:.*]] = llvm.bitcast %[[VAL_4]] : !llvm.ptr<i8> to !llvm.ptr<i64>
// CHECK: %[[VAL_6:.*]] = llvm.load %[[VAL_5]] : !llvm.ptr<i64>
// CHECK: %[[VAL_7:.*]] = llvm.mlir.constant(1 : i32) : !llvm.i32
// CHECK: %[[VAL_8:.*]] = llvm.getelementptr %[[VAL_0]]{{\[}}%[[VAL_7]]] : (!llvm.ptr<ptr<i8>>, !llvm.i32) -> !llvm.ptr<ptr<i8>>
// CHECK: %[[VAL_9:.*]] = llvm.load %[[VAL_8]] : !llvm.ptr<ptr<i8>>
// CHECK: %[[VAL_10:.*]] = llvm.bitcast %[[VAL_9]] : !llvm.ptr<i8> to !llvm.ptr<ptr<i8>>
// CHECK: %[[VAL_11:.*]] = llvm.load %[[VAL_10]] : !llvm.ptr<ptr<i8>>
// CHECK: %[[VAL_12:.*]] = llvm.call @inputs1results1(%[[VAL_6]], %[[VAL_11]]) : (!llvm.i64, !llvm.ptr<i8>) -> !llvm.struct<(i64, ptr<i8>)>
// CHECK: %[[VAL_13:.*]] = llvm.mlir.constant(0 : i32) : !llvm.i32
// CHECK: %[[VAL_14:.*]] = llvm.getelementptr %[[VAL_1]]{{\[}}%[[VAL_13]]] : (!llvm.ptr<ptr<i8>>, !llvm.i32) -> !llvm.ptr<ptr<i8>>
// CHECK: %[[VAL_15:.*]] = llvm.load %[[VAL_14]] : !llvm.ptr<ptr<i8>>
// CHECK: %[[VAL_16:.*]] = llvm.bitcast %[[VAL_15]] : !llvm.ptr<i8> to !llvm.ptr<struct<(i64, ptr<i8>)>>
// CHECK: llvm.store %[[VAL_12]], %[[VAL_16]] : !llvm.ptr<struct<(i64, ptr<i8>)>>
// CHECK: llvm.return
// CHECK: }
// CHECK-LABEL: llvm.func @__refbackrt_wrapper_inputs1results0(
// CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr<ptr<i8>>,
// CHECK-SAME: %[[VAL_1:.*]]: !llvm.ptr<ptr<i8>>) {
// CHECK: %[[VAL_2:.*]] = llvm.mlir.constant(0 : i32) : !llvm.i32
// CHECK: %[[VAL_3:.*]] = llvm.getelementptr %[[VAL_0]]{{\[}}%[[VAL_2]]] : (!llvm.ptr<ptr<i8>>, !llvm.i32) -> !llvm.ptr<i8>
// CHECK: %[[VAL_4:.*]] = llvm.bitcast %[[VAL_3]] : !llvm.ptr<i8> to !llvm.ptr<ptr<i8>>
// CHECK: %[[VAL_5:.*]] = llvm.load %[[VAL_4]] : !llvm.ptr<ptr<i8>>
// CHECK: llvm.call @inputs1results0(%[[VAL_5]]) : (!llvm.ptr<i8>) -> ()
// CHECK: %[[VAL_3:.*]] = llvm.getelementptr %[[VAL_0]]{{\[}}%[[VAL_2]]] : (!llvm.ptr<ptr<i8>>, !llvm.i32) -> !llvm.ptr<ptr<i8>>
// CHECK: %[[VAL_4:.*]] = llvm.load %[[VAL_3]] : !llvm.ptr<ptr<i8>>
// CHECK: %[[VAL_5:.*]] = llvm.bitcast %[[VAL_4]] : !llvm.ptr<i8> to !llvm.ptr<i64>
// CHECK: %[[VAL_6:.*]] = llvm.load %[[VAL_5]] : !llvm.ptr<i64>
// CHECK: %[[VAL_7:.*]] = llvm.mlir.constant(1 : i32) : !llvm.i32
// CHECK: %[[VAL_8:.*]] = llvm.getelementptr %[[VAL_0]]{{\[}}%[[VAL_7]]] : (!llvm.ptr<ptr<i8>>, !llvm.i32) -> !llvm.ptr<ptr<i8>>
// CHECK: %[[VAL_9:.*]] = llvm.load %[[VAL_8]] : !llvm.ptr<ptr<i8>>
// CHECK: %[[VAL_10:.*]] = llvm.bitcast %[[VAL_9]] : !llvm.ptr<i8> to !llvm.ptr<ptr<i8>>
// CHECK: %[[VAL_11:.*]] = llvm.load %[[VAL_10]] : !llvm.ptr<ptr<i8>>
// CHECK: llvm.call @inputs1results0(%[[VAL_6]], %[[VAL_11]]) : (!llvm.i64, !llvm.ptr<i8>) -> ()
// CHECK: llvm.return
// CHECK: }
// CHECK: llvm.func @__npcomp_compiler_rt_abort_if(!llvm.i1, !llvm.ptr<i8>)
// CHECK: llvm.func @__npcomp_compiler_rt_to_memref(!llvm.ptr<i8>) -> !llvm.struct<(i64, ptr<i8>)>
// CHECK: llvm.func @__npcomp_compiler_rt_from_memref(!llvm.i64, !llvm.ptr<i8>) -> !llvm.ptr<i8>
// CHECK: llvm.mlir.global internal constant @__npcomp_internal_constant_inputs1results0("inputs1results0")
// CHECK: llvm.mlir.global internal constant @__npcomp_internal_constant_inputs1results1("inputs1results1")
// CHECK: llvm.mlir.global internal constant @__npcomp_internal_constant_inputs1results2("inputs1results2")
@@ -175,44 +132,24 @@ refbackrt.module_metadata {
refbackrt.func_metadata {funcName = @inputs1results2, numInputs = 1 : i32, numOutputs = 2 : i32}
}
// CHECK-LABEL: llvm.func @inputs1results0(
// CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr<i8>) {
// CHECK: llvm.return
// CHECK: }
func @inputs1results0(%arg0: !refbackrt.tensor) {
func @inputs1results0(%arg0: memref<*xf32>) {
return
}
// CHECK-LABEL: llvm.func @inputs1results1(
// CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr<i8>) -> !llvm.ptr<i8> {
// CHECK: llvm.return %[[VAL_0]] : !llvm.ptr<i8>
// CHECK: }
func @inputs1results1(%arg0: !refbackrt.tensor) -> !refbackrt.tensor {
return %arg0 : !refbackrt.tensor
func @inputs1results1(%arg0: memref<*xf32>) -> memref<*xf32> {
return %arg0 : memref<*xf32>
}
// CHECK-LABEL: llvm.func @inputs1results2(
// CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr<i8>) -> !llvm.struct<(ptr<i8>, ptr<i8>)> {
// CHECK: %[[VAL_1:.*]] = llvm.mlir.undef : !llvm.struct<(ptr<i8>, ptr<i8>)>
// CHECK: %[[VAL_2:.*]] = llvm.insertvalue %[[VAL_0]], %[[VAL_1]][0] : !llvm.struct<(ptr<i8>, ptr<i8>)>
// CHECK: %[[VAL_3:.*]] = llvm.insertvalue %[[VAL_0]], %[[VAL_2]][1] : !llvm.struct<(ptr<i8>, ptr<i8>)>
// CHECK: llvm.return %[[VAL_3]] : !llvm.struct<(ptr<i8>, ptr<i8>)>
// CHECK: }
func @inputs1results2(%arg0: !refbackrt.tensor) -> (!refbackrt.tensor, !refbackrt.tensor) {
return %arg0, %arg0 : !refbackrt.tensor, !refbackrt.tensor
func @inputs1results2(%arg0: memref<*xf32>) -> (memref<*xf32>, memref<*xf32>) {
return %arg0, %arg0 : memref<*xf32>, memref<*xf32>
}
// -----
// Test emission of compiler runtime functions.
// CHECK: llvm.mlir.global internal constant @[[STRSYM:.*]]("msg\00")
// CHECK: llvm.func @__npcomp_compiler_rt_abort_if(!llvm.i1, !llvm.ptr<i8>)
// CHECK: llvm.func @__npcomp_compiler_rt_to_memref(!llvm.ptr<i8>) -> !llvm.struct<(i64, ptr<i8>)>
// CHECK: llvm.func @__npcomp_compiler_rt_from_memref(!llvm.i64, !llvm.ptr<i8>) -> !llvm.ptr<i8>
// CHECK-LABEL: llvm.func @calls_abort_if(
// CHECK-SAME: %[[VAL_0:.*]]: !llvm.i1) {
@@ -226,29 +163,3 @@ func @calls_abort_if(%arg0: i1) {
refbackrt.abort_if %arg0, "msg"
return
}
// CHECK-LABEL: llvm.func @calls_to_memref(
// CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr<i8>) {
// CHECK: %[[VAL_1:.*]] = llvm.call @__npcomp_compiler_rt_to_memref(%[[VAL_0]]) : (!llvm.ptr<i8>) -> !llvm.struct<(i64, ptr<i8>)>
// CHECK: llvm.return
// CHECK: }
func @calls_to_memref(%arg0: !refbackrt.tensor) {
%0 = refbackrt.to_memref %arg0 : memref<*xf32>
return
}
// CHECK-LABEL: llvm.func @calls_from_memref(
// CHECK-SAME: %[[VAL_0:.*]]: !llvm.i64,
// CHECK-SAME: %[[VAL_1:.*]]: !llvm.ptr<i8>) -> !llvm.ptr<i8> {
// CHECK: %[[VAL_2:.*]] = llvm.mlir.undef : !llvm.struct<(i64, ptr<i8>)>
// CHECK: %[[VAL_3:.*]] = llvm.insertvalue %[[VAL_0]], %[[VAL_2]][0] : !llvm.struct<(i64, ptr<i8>)>
// CHECK: %[[VAL_4:.*]] = llvm.insertvalue %[[VAL_1]], %[[VAL_3]][1] : !llvm.struct<(i64, ptr<i8>)>
// CHECK: %[[VAL_5:.*]] = llvm.extractvalue %[[VAL_4]][0 : i32] : !llvm.struct<(i64, ptr<i8>)>
// CHECK: %[[VAL_6:.*]] = llvm.extractvalue %[[VAL_4]][1 : i32] : !llvm.struct<(i64, ptr<i8>)>
// CHECK: %[[VAL_7:.*]] = llvm.call @__npcomp_compiler_rt_from_memref(%[[VAL_5]], %[[VAL_6]]) : (!llvm.i64, !llvm.ptr<i8>) -> !llvm.ptr<i8>
// CHECK: llvm.return %[[VAL_7]] : !llvm.ptr<i8>
// CHECK: }
func @calls_from_memref(%arg0: memref<*xf32>) -> !refbackrt.tensor {
%0 = refbackrt.from_memref %arg0 : memref<*xf32>
return %0 : !refbackrt.tensor
}


@@ -20,34 +20,18 @@ func @f_1input_2outputs(%arg0: memref<?xf32>) -> (memref<?xf32>, memref<?xf32>)
// Test ABI conversions.
// CHECK-LABEL: func @identity(%arg0: !refbackrt.tensor) -> !refbackrt.tensor
// CHECK-LABEL: func @identity(%arg0: memref<*xf32>) -> memref<*xf32>
func @identity(%arg0: memref<?xf32>) -> memref<?xf32> {
// The argument materialization.
// In this test case, these go unused since, as described below, the new
// argument value is seen immediately by the return op for some reason.
// CHECK-NEXT: %[[INABIMEMREF:.*]] = refbackrt.to_memref %arg0 : memref<*xf32>
// CHECK-NEXT: %[[MEMREF:.*]] = memref_cast %[[INABIMEMREF]] : memref<*xf32> to memref<?xf32>
// TODO: Why do these target materializations not happen in this particular
// test?
// Somehow, the return op rewrite sees the new argument value immediately,
// rather than the result of replaceUsesOfBlockArgument from
// FuncOpSignatureConversion
// Cxxxx-NEXT: %[[OUTABIMEMREF:.*]] = memref_cast %[[MEMREF]] : memref<?xf32> to memref<*xf32>
// Cxxxx-NEXT: %[[RET:.*]] = refbackrt.from_memref %[[OUTABIMEMREF]] : memref<*xf32>
// Cxxxx-NEXT: return %[[RET]]
// CHECK-NEXT: return %arg0
// CHECK: return %arg0
return %arg0 : memref<?xf32>
}
// -----
// CHECK-LABEL: func @use_of_arg(%arg0: !refbackrt.tensor)
// CHECK-LABEL: func @use_of_arg(%arg0: memref<*xf32>)
func @use_of_arg(%arg0: memref<?xf32>) {
// CHECK-NEXT: %[[INABIMEMREF:.*]] = refbackrt.to_memref %arg0 : memref<*xf32>
// CHECK-NEXT: %[[MEMREF:.*]] = memref_cast %[[INABIMEMREF]] : memref<*xf32> to memref<?xf32>
// CHECK-NEXT: %[[MEMREF:.*]] = memref_cast %arg0 : memref<*xf32> to memref<?xf32>
%c0 = constant 0 : index
%0 = dim %arg0, %c0 : memref<?xf32>
// CHECK-NEXT: %[[C0:.*]] = constant 0 : index
@@ -57,17 +41,15 @@ func @use_of_arg(%arg0: memref<?xf32>) {
// -----
// CHECK-LABEL: func @multiple_blocks(%arg0: !refbackrt.tensor) -> !refbackrt.tensor
// CHECK-LABEL: func @multiple_blocks(%arg0: memref<*xf32>) -> memref<*xf32>
func @multiple_blocks(%arg0: memref<?xf32>) -> memref<?xf32> {
// CHECK-NEXT: %[[INABIMEMREF:.*]] = refbackrt.to_memref %arg0 : memref<*xf32>
// CHECK-NEXT: %[[INMEMREF:.*]] = memref_cast %[[INABIMEMREF]] : memref<*xf32> to memref<?xf32>
// CHECK-NEXT: %[[INMEMREF:.*]] = memref_cast %arg0 : memref<*xf32> to memref<?xf32>
// CHECK-NEXT: br ^bb1(%[[INMEMREF]] : memref<?xf32>)
br ^bb1(%arg0: memref<?xf32>)
// CHECK-NEXT: ^bb1(%[[BBARG:.*]]: memref<?xf32>):
^bb1(%bbarg: memref<?xf32>):
// CHECK-NEXT: %[[OUTMEMREF:.*]] = memref_cast %[[BBARG]] : memref<?xf32> to memref<*xf32>
// CHECK-NEXT: %[[OUTABIMEMREF:.*]] = refbackrt.from_memref %[[OUTMEMREF]] : memref<*xf32>
// CHECK-NEXT: return %[[OUTABIMEMREF]] : !refbackrt.tensor
// CHECK-NEXT: return %[[OUTMEMREF]] : memref<*xf32>
return %bbarg : memref<?xf32>
}