diff --git a/include/npcomp/Dialect/Refbackrt/IR/RefbackrtBase.td b/include/npcomp/Dialect/Refbackrt/IR/RefbackrtBase.td
index 8633a6ed9..5a2232efb 100644
--- a/include/npcomp/Dialect/Refbackrt/IR/RefbackrtBase.td
+++ b/include/npcomp/Dialect/Refbackrt/IR/RefbackrtBase.td
@@ -23,14 +23,4 @@ lowered to the llvm dialect.
 }];
 }
 
-def Refbackrt_Tensor
-    : DialectType<
-          Refbackrt_Dialect,
-          CPred<"$_self.isa<::mlir::NPCOMP::refbackrt::TensorType>()">,
-          "refbackrt.tensor">,
-      BuildableType<
-          "$_builder.getType<::mlir::NPCOMP::refbackrt::TensorType>()"> {
-  let typeDescription = [{The runtime type that represents a buffer.}];
-}
-
 #endif // #ifndef REFBACKRT_BASE
diff --git a/include/npcomp/Dialect/Refbackrt/IR/RefbackrtOps.td b/include/npcomp/Dialect/Refbackrt/IR/RefbackrtOps.td
index 732d6e059..4600adf0f 100644
--- a/include/npcomp/Dialect/Refbackrt/IR/RefbackrtOps.td
+++ b/include/npcomp/Dialect/Refbackrt/IR/RefbackrtOps.td
@@ -16,26 +16,6 @@ class Refbackrt_Op<string mnemonic, list<OpTrait> traits = []>
     : Op<Refbackrt_Dialect, mnemonic, traits> {
 }
 
-def Refbackrt_ToMemrefOp : Refbackrt_Op<"to_memref"> {
-  let summary = "Gets a memref descriptor from a tensor";
-  let description = [{
-    Gets a memref descriptor from a tensor.
-  }];
-  let arguments = (ins Refbackrt_Tensor:$tensor);
-  let results = (outs AnyUnrankedMemRef:$memref);
-  let assemblyFormat = "$tensor attr-dict `:` type($memref)";
-}
-
-def Refbackrt_FromMemrefOp : Refbackrt_Op<"from_memref"> {
-  let summary = "Converts a memref descriptor to a tensor";
-  let description = [{
-    Copies the data from a memref into a new tensor.
-  }];
-  let arguments = (ins AnyUnrankedMemRef:$memref);
-  let results = (outs Refbackrt_Tensor:$tensor);
-  let assemblyFormat = "$memref attr-dict `:` type($memref)";
-}
-
 def Refbackrt_AbortIfOp : Refbackrt_Op<"abort_if"> {
   let summary = "Aborts if the predicate is true";
   let description = [{
diff --git a/include/npcomp/RefBackend/Runtime/UserAPI.h b/include/npcomp/RefBackend/Runtime/UserAPI.h
index 05ae6cb64..8ba68cb1e 100644
--- a/include/npcomp/RefBackend/Runtime/UserAPI.h
+++ b/include/npcomp/RefBackend/Runtime/UserAPI.h
@@ -70,6 +70,8 @@ public:
     return ret;
   }
 
+  int debugGetRefCount() { return ptr->refCount; }
+
 private:
   static void incref(T *ptr) {
     if (!ptr)
diff --git a/lib/Dialect/Refbackrt/IR/RefbackrtDialect.cpp b/lib/Dialect/Refbackrt/IR/RefbackrtDialect.cpp
index 12c4d4063..7d534da46 100644
--- a/lib/Dialect/Refbackrt/IR/RefbackrtDialect.cpp
+++ b/lib/Dialect/Refbackrt/IR/RefbackrtDialect.cpp
@@ -21,23 +21,3 @@ void RefbackrtDialect::initialize() {
       >();
   addTypes<TensorType>();
 }
-
-Type RefbackrtDialect::parseType(DialectAsmParser &parser) const {
-  StringRef keyword;
-  if (parser.parseKeyword(&keyword))
-    return Type();
-
-  if (keyword == "tensor")
-    return TensorType::get(getContext());
-
-  parser.emitError(parser.getNameLoc(), "unknown type in 'refbackrt' dialect: ")
-      << keyword;
-  return Type();
-}
-
-void RefbackrtDialect::printType(Type type, DialectAsmPrinter &os) const {
-  TypeSwitch<Type>(type)
-      .Case<TensorType>([&](Type) { os << "tensor"; })
-      .Default(
-          [&](Type) { llvm_unreachable("unexpected 'refbackrt' type kind"); });
-}
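The new debugGetRefCount() hook on Ref<T> makes refcount invariants observable from tests. A minimal sketch of the kind of assertion it enables; the Tensor::create factory used here is assumed from the surrounding user API, which is not shown in this patch:

    #include "npcomp/RefBackend/Runtime/UserAPI.h"
    #include <cassert>

    using namespace refbackrt;

    // Hypothetical test helper; exact factory signatures are assumptions.
    static void checkRefCountBalance() {
      float data[] = {1.0f, 2.0f};
      std::int32_t extents[] = {2};
      Ref<Tensor> t = Tensor::create(ArrayRef<std::int32_t>(extents, 1),
                                     ElementType::F32, data);
      assert(t.debugGetRefCount() == 1);
      {
        Ref<Tensor> alias = t; // copy increments the refcount
        assert(t.debugGetRefCount() == 2);
      } // the alias's destructor decrements it again
      assert(t.debugGetRefCount() == 1);
    }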
diff --git a/lib/RefBackend/LowerToLLVM.cpp b/lib/RefBackend/LowerToLLVM.cpp
index 074439f3b..6bfaf2b5d 100644
--- a/lib/RefBackend/LowerToLLVM.cpp
+++ b/lib/RefBackend/LowerToLLVM.cpp
@@ -77,35 +77,6 @@ public:
 };
 } // namespace
 
-namespace {
-// FromMemrefOp requires special handling so that the unranked memref descriptor
-// gets passed as two separate arguments instead of as a struct.
-class FromMemrefOpCompilerRuntimeLowering
-    : public OpConversionPattern<refbackrt::FromMemrefOp> {
-public:
-  FromMemrefOpCompilerRuntimeLowering(LLVM::LLVMFuncOp backingFunc)
-      : OpConversionPattern<refbackrt::FromMemrefOp>(backingFunc.getContext()),
-        backingFunc(backingFunc) {}
-  LogicalResult
-  matchAndRewrite(refbackrt::FromMemrefOp op, ArrayRef<Value> operands,
-                  ConversionPatternRewriter &rewriter) const override {
-    auto structVal = operands[0];
-    Value rank = rewriter.create<LLVM::ExtractValueOp>(
-        op.getLoc(),
-        structVal.getType().cast<LLVMType>().getStructElementType(0), structVal,
-        rewriter.getI32ArrayAttr({0}));
-    Value descriptorPtr = rewriter.create<LLVM::ExtractValueOp>(
-        op.getLoc(),
-        structVal.getType().cast<LLVMType>().getStructElementType(1), structVal,
-        rewriter.getI32ArrayAttr({1}));
-    rewriter.replaceOpWithNewOp<LLVM::CallOp>(
-        op, backingFunc, ValueRange({rank, descriptorPtr}));
-    return success();
-  }
-  LLVM::LLVMFuncOp backingFunc;
-};
-} // namespace
-
 static LLVM::GlobalOp createGlobalString(ModuleOp module, StringAttr msg,
                                          OpBuilder &builder, Location loc) {
   // TODO: Deduplicate strings.
@@ -188,35 +159,6 @@ static void populateCompilerRuntimePatterns(ModuleOp module,
         "abort_if", abortIfFuncTy, builder, module.getLoc());
     patterns.insert<AbortIfOpCompilerRuntimeLowering>(abortIfFunc);
   }
-
-  auto convertFunctionType = [&](FunctionType type) {
-    TypeConverter::SignatureConversion conversion(type.getNumInputs());
-    return typeConverter.convertFunctionSignature(type, /*isVariadic=*/false,
-                                                  conversion);
-  };
-
-  {
-    auto mlirFunctionType = builder.getFunctionType(
-        {builder.getType<refbackrt::TensorType>()},
-        {UnrankedMemRefType::get(builder.getF32Type(), /*memorySpace=*/0)});
-    LLVMType funcTy = convertFunctionType(mlirFunctionType);
-    LLVMFuncOp toMemrefFunc = createCompilerRuntimeFuncDecl(
-        "to_memref", funcTy, builder, module.getLoc());
-    patterns.insert<TrivialCompilerRuntimeLowering<refbackrt::ToMemrefOp>>(
-        toMemrefFunc);
-  }
-
-  {
-    // TODO: Pass in an element type enum, since the unranked memref descriptor
-    // doesn't know its own dtype.
-    auto mlirFunctionType = builder.getFunctionType(
-        {UnrankedMemRefType::get(builder.getF32Type(), /*memorySpace=*/0)},
-        {builder.getType<refbackrt::TensorType>()});
-    LLVMType funcTy = convertFunctionType(mlirFunctionType);
-    LLVMFuncOp fromMemrefFunc = createCompilerRuntimeFuncDecl(
-        "from_memref", funcTy, builder, module.getLoc());
-    patterns.insert<FromMemrefOpCompilerRuntimeLowering>(fromMemrefFunc);
-  }
 }
 
 //===----------------------------------------------------------------------===//
@@ -390,9 +332,12 @@ static Value getTypedAddressFromVoidStarStar(Value voidStarStar, int32_t index,
   Value ci = builder.create<LLVM::ConstantOp>(
       loc, LLVMType::getIntNTy(builder.getContext(), 32),
       builder.getI32IntegerAttr(index));
-  auto inputPtr = builder.create<LLVM::GEPOp>(
-      loc, LLVMType::getInt8PtrTy(builder.getContext()), voidStarStar,
-      ValueRange(ci));
+
+  // Do `voidStarStar[i]` as a gep + load.
+  auto inputPtrAddr = builder.create<LLVM::GEPOp>(
+      loc, LLVMType::getInt8PtrTy(builder.getContext()).getPointerTo(),
+      voidStarStar, ValueRange(ci));
+  auto inputPtr = builder.create<LLVM::LoadOp>(loc, inputPtrAddr);
   return builder.create<LLVM::BitcastOp>(loc, ty.getPointerTo(), inputPtr);
 }
 
@@ -409,6 +354,21 @@ static SmallVector<Value, 8> loadCallArgs(Value inputsPtrPtr, LLVMType funcTy,
   return callArgs;
 }
 
+static LLVM::LLVMType getUnrankedMemrefDescriptorType(MLIRContext *context) {
+  LLVMTypeConverter converter(context);
+  // LLVMTypeConverter doesn't directly expose the struct type used to represent
+  // unranked memrefs on ABI boundaries. To get that type, we convert
+  // an unranked memref type and see what it produces.
+  //
+  // An unranked memref is just a size_t for the rank and a void* pointer to
+  // the descriptor, so the choice of element type here is arbitrary -- it all
+  // converts to the same thing.
+  return converter
+      .convertType(UnrankedMemRefType::get(Float32Type::get(context),
+                                           /*memorySpace=*/0))
+      .cast<LLVM::LLVMType>();
+}
+
 // Writes out the logical results of the wrapper function through the void**
 // passed on the ABI boundary. Because LLVM (and hence llvm.func)
 // only supports a single return type (or void/no results), the logic here needs
@@ -424,13 +384,15 @@ static void storeWrapperResults(LLVM::CallOp callToWrapped, Value resultsPtrPtr,
   Value result = callToWrapped.getResult(0);
   auto ty = result.getType().cast<LLVMType>();
   // 1 logical result.
-  if (!ty.isStructTy()) {
+  if (ty == getUnrankedMemrefDescriptorType(ty.getContext())) {
     Value addr =
         getTypedAddressFromVoidStarStar(resultsPtrPtr, 0, ty, builder, loc);
     builder.create<LLVM::StoreOp>(loc, result, addr);
     return;
   }
-  // >=2 logical results.
+  assert(ty.isStructTy() && "must be a multi-result packed struct!");
+  // >=2 logical results. The convention linked above wraps multiple results
+  // in a struct.
   for (int i = 0, e = ty.getStructNumElements(); i < e; i++) {
     auto elementTy = ty.getStructElementType(i);
     Value addr = getTypedAddressFromVoidStarStar(resultsPtrPtr, i, elementTy,
@@ -492,11 +454,6 @@ class LowerToLLVM : public LowerToLLVMBase<LowerToLLVM> {
 
     LLVMTypeConverter converter(context);
 
-    // refbackrt::TensorType is passed as a `void*` in the ABI.
-    converter.addConversion([&](refbackrt::TensorType type) {
-      return LLVMType::getInt8PtrTy(context);
-    });
-
     OwningRewritePatternList patterns;
     LLVMConversionTarget target(*context);
     target.addDynamicallyLegalOp<FuncOp>(
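To make the new wrapper convention concrete: each unranked-memref argument now occupies two consecutive void* slots in the inputs array (a pointer to the i64 rank, then a pointer to the descriptor pointer), and each result is stored as a whole {rank, descriptor} struct through one void* slot in the outputs array. The sketch below shows in plain C++ what an emitted __refbackrt_wrapper effectively does for a one-input, one-result function; all names are illustrative, not generated symbols:

    #include <cstdint>

    // Mirrors the {rank, descriptor*} struct MLIR uses for unranked memrefs.
    struct UnrankedMemrefABI {
      std::int64_t rank;
      void *descriptor;
    };

    // Stands in for the compiled llvm.func being wrapped.
    UnrankedMemrefABI compiledFunc(std::int64_t rank, void *descriptor);

    // Sketch (not generated code) of what an emitted wrapper does.
    void exampleWrapper(void **inputs, void **outputs) {
      // Input 0 occupies slots 0 and 1: &rank and &descriptor.
      std::int64_t rank = *static_cast<std::int64_t *>(inputs[0]);
      void *descriptor = *static_cast<void **>(inputs[1]);
      UnrankedMemrefABI result = compiledFunc(rank, descriptor);
      // Result 0 is stored as a whole struct through out-slot 0.
      *static_cast<UnrankedMemrefABI *>(outputs[0]) = result;
    }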
diff --git a/lib/RefBackend/LowerToRefbackrtABI.cpp b/lib/RefBackend/LowerToRefbackrtABI.cpp
index 4d28c4631..6d9005b36 100644
--- a/lib/RefBackend/LowerToRefbackrtABI.cpp
+++ b/lib/RefBackend/LowerToRefbackrtABI.cpp
@@ -97,7 +97,8 @@ public:
 } // namespace
 
 namespace {
-// At ABI bondaries, use !refbackrt.tensor instead of memref.
+// At ABI boundaries, convert all memrefs to unranked memrefs so that they have
+// a fixed ABI.
 class FuncOpSignatureConversion : public OpConversionPattern<FuncOp> {
 public:
   using OpConversionPattern::OpConversionPattern;
@@ -128,9 +129,7 @@ public:
       for (auto newAndOldArg :
            llvm::zip(newEntry.getArguments(), oldEntry.getArguments())) {
         std::tie(newArg, oldArg) = newAndOldArg;
-        auto abiMemref = rewriter.create<refbackrt::ToMemrefOp>(
-            op.getLoc(), getABIMemrefType(oldArg.getType()), newArg);
-        auto memref = rewriter.create<MemRefCastOp>(op.getLoc(), abiMemref,
+        auto memref = rewriter.create<MemRefCastOp>(op.getLoc(), newArg,
                                                     oldArg.getType());
         rewriter.replaceUsesOfBlockArgument(oldArg, memref);
       }
@@ -141,7 +140,7 @@ public:
 } // namespace
 
 namespace {
-// At the return ABI boundaries, convert to !refbackrt.tensor type.
+// At the return ABI boundaries, convert to the ABI type.
 // This pattern is needed to trigger the type conversion mechanics to do a
 // target materialization.
 class RewriteReturnOp : public OpConversionPattern<ReturnOp> {
@@ -161,16 +160,14 @@ static LogicalResult doDialectConversion(ModuleOp module) {
 
   TypeConverter typeConverter;
   typeConverter.addConversion([](Type type) { return type; });
-  typeConverter.addConversion([](MemRefType type) {
-    return refbackrt::TensorType::get(type.getContext());
-  });
+  typeConverter.addConversion(
+      [](MemRefType type) { return getABIMemrefType(type); });
   typeConverter.addTargetMaterialization(
-      [](OpBuilder &builder, refbackrt::TensorType type, ValueRange inputs,
+      [](OpBuilder &builder, UnrankedMemRefType type, ValueRange inputs,
          Location loc) -> Value {
         assert(inputs.size() == 1);
-        auto abiMemref = builder.create<MemRefCastOp>(
+        return builder.create<MemRefCastOp>(
             loc, inputs[0], getABIMemrefType(inputs[0].getType()));
-        return builder.create<refbackrt::FromMemrefOp>(loc, type, abiMemref);
       });
 
   OwningRewritePatternList patterns;
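These patterns lean on a small helper, getABIMemrefType, that is referenced but not shown in this patch. A plausible minimal definition (an assumption for illustration, not code from this change) erases the shape and keeps only the element type and memory space, since the unranked form is what has a fixed descriptor layout:

    // Hypothetical sketch; the patch's actual helper may differ.
    static Type getABIMemrefType(Type type) {
      auto memrefType = type.cast<MemRefType>();
      // memref<?xf32> (any rank/shape) -> memref<*xf32>.
      return UnrankedMemRefType::get(memrefType.getElementType(),
                                     memrefType.getMemorySpace());
    }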
diff --git a/lib/RefBackend/Runtime/CompilerRuntime.cpp b/lib/RefBackend/Runtime/CompilerRuntime.cpp
index 3bfae89b0..8e32193fd 100644
--- a/lib/RefBackend/Runtime/CompilerRuntime.cpp
+++ b/lib/RefBackend/Runtime/CompilerRuntime.cpp
@@ -26,91 +26,3 @@ extern "C" void __npcomp_compiler_rt_abort_if(bool b, const char *msg) {
     std::exit(1);
   }
 }
-
-namespace {
-// These definitions are based on the ones in
-// `mlir/ExecutionEngine/CRunnerUtils.h` and the layouts need to be kept in
-// sync.
-//
-// Those definitions are flawed though because they are overly templated.
-struct MemrefDescriptor {
-  void *allocatedPtr;
-  void *dataPtr;
-  std::int64_t offset;
-  // Tail-allocated int64_t sizes followed by strides.
-  MutableArrayRef<std::int64_t> getSizes(int assumedRank) {
-    auto *tail = reinterpret_cast<std::int64_t *>(this + 1);
-    return MutableArrayRef<std::int64_t>(tail, assumedRank);
-  }
-  MutableArrayRef<std::int64_t> getStrides(int assumedRank) {
-    auto *tail = reinterpret_cast<std::int64_t *>(this + 1);
-    return MutableArrayRef<std::int64_t>(tail + assumedRank, assumedRank);
-  }
-
-  // Returns a malloc-allocated MemrefDescriptor with the specified extents and
-  // default striding.
-  static MemrefDescriptor *create(ArrayRef<std::int32_t> extents, void *data);
-};
-
-struct UnrankedMemref {
-  int64_t rank;
-  MemrefDescriptor *descriptor;
-};
-
-MemrefDescriptor *MemrefDescriptor::create(ArrayRef<std::int32_t> extents,
-                                           void *data) {
-  auto rank = extents.size();
-  auto allocSize = sizeof(MemrefDescriptor) + sizeof(std::int64_t) * 2 * rank;
-  auto *descriptor = static_cast<MemrefDescriptor *>(std::malloc(allocSize));
-  descriptor->allocatedPtr = data;
-  descriptor->dataPtr = data;
-  descriptor->offset = 0;
-  // Iterate in reverse, copying the dimension sizes (i.e. extents) and
-  // calculating the strides for a standard dense layout.
-  std::int64_t stride = 1;
-  for (int i = 0, e = rank; i < e; i++) {
-    auto revIdx = e - i - 1;
-    descriptor->getSizes(rank)[revIdx] = extents[revIdx];
-    descriptor->getStrides(rank)[revIdx] = stride;
-    stride *= extents[revIdx];
-  }
-  return descriptor;
-}
-
-std::int32_t getNumElements(MemrefDescriptor *descriptor, int assumedRank) {
-  if (assumedRank == 0)
-    return 1;
-  return descriptor->getSizes(assumedRank)[0] *
-         descriptor->getStrides(assumedRank)[0];
-}
-
-} // namespace
-
-extern "C" UnrankedMemref __npcomp_compiler_rt_to_memref(Tensor *tensor) {
-  auto byteSize = tensor->getDataByteSize();
-  void *data = std::malloc(byteSize);
-  std::memcpy(data, tensor->getData(), byteSize);
-  auto *descriptor = MemrefDescriptor::create(tensor->getExtents(), data);
-  return UnrankedMemref{tensor->getRank(), descriptor};
-}
-
-extern "C" Tensor *
-__npcomp_compiler_rt_from_memref(std::int64_t rank,
-                                 MemrefDescriptor *descriptor) {
-  auto numElements = getNumElements(descriptor, rank);
-  // TODO: Have the compiler pass this as an argument.
-  auto elementType = ElementType::F32;
-
-  auto byteSize = getElementTypeByteSize(elementType) * numElements;
-  void *data = std::malloc(byteSize);
-  std::memcpy(data, descriptor->dataPtr, byteSize);
-  auto extents64 = descriptor->getSizes(rank);
-
-  // Launder from std::int64_t to std::int32_t.
-  constexpr int kMaxRank = 20;
-  std::array<std::int32_t, kMaxRank> extents32Buf;
-  for (int i = 0, e = extents64.size(); i < e; i++)
-    extents32Buf[i] = extents64[i];
-  return Tensor::createRaw(ArrayRef<std::int32_t>(extents32Buf.data(), rank),
-                           elementType, data);
-}
diff --git a/lib/RefBackend/Runtime/Runtime.cpp b/lib/RefBackend/Runtime/Runtime.cpp
index c8bc03990..c397f201c 100644
--- a/lib/RefBackend/Runtime/Runtime.cpp
+++ b/lib/RefBackend/Runtime/Runtime.cpp
@@ -17,6 +17,91 @@
 
 using namespace refbackrt;
 
+//===----------------------------------------------------------------------===//
+// Memref descriptors for interacting with MLIR codegenerated code.
+//===----------------------------------------------------------------------===//
+
+namespace {
+// These definitions are based on the ones in
+// `mlir/ExecutionEngine/CRunnerUtils.h` and the layouts need to be kept in
+// sync.
+//
+// Those definitions are flawed though because they are overly templated.
+struct MemrefDescriptor {
+  void *allocatedPtr;
+  void *dataPtr;
+  std::int64_t offset;
+  // Tail-allocated int64_t sizes followed by strides.
+  MutableArrayRef<std::int64_t> getSizes(int assumedRank) {
+    auto *tail = reinterpret_cast<std::int64_t *>(this + 1);
+    return MutableArrayRef<std::int64_t>(tail, assumedRank);
+  }
+  MutableArrayRef<std::int64_t> getStrides(int assumedRank) {
+    auto *tail = reinterpret_cast<std::int64_t *>(this + 1);
+    return MutableArrayRef<std::int64_t>(tail + assumedRank, assumedRank);
+  }
+
+  // Returns a malloc-allocated MemrefDescriptor with the specified extents and
+  // default striding.
+  static MemrefDescriptor *create(ArrayRef<std::int32_t> extents, void *data);
+
+  // Returns the number of elements in this MemrefDescriptor, assuming this
+  // descriptor has rank `assumedRank`.
+  std::int32_t getNumElements(int assumedRank) {
+    if (assumedRank == 0)
+      return 1;
+    return getSizes(assumedRank)[0] * getStrides(assumedRank)[0];
+  }
+};
+} // namespace
+
+namespace {
+struct UnrankedMemref {
+  int64_t rank;
+  MemrefDescriptor *descriptor;
+};
+} // namespace
+
+MemrefDescriptor *MemrefDescriptor::create(ArrayRef<std::int32_t> extents,
+                                           void *data) {
+  auto rank = extents.size();
+  auto allocSize = sizeof(MemrefDescriptor) + sizeof(std::int64_t) * 2 * rank;
+  auto *descriptor = static_cast<MemrefDescriptor *>(std::malloc(allocSize));
+  descriptor->allocatedPtr = data;
+  descriptor->dataPtr = data;
+  descriptor->offset = 0;
+  // Iterate in reverse, copying the dimension sizes (i.e. extents) and
+  // calculating the strides for a standard dense layout.
+  std::int64_t stride = 1;
+  for (int i = 0, e = rank; i < e; i++) {
+    auto revIdx = e - i - 1;
+    descriptor->getSizes(rank)[revIdx] = extents[revIdx];
+    descriptor->getStrides(rank)[revIdx] = stride;
+    stride *= extents[revIdx];
+  }
+  return descriptor;
+}
+
+static UnrankedMemref convertRefbackrtTensorToUnrankedMemref(Tensor *tensor) {
+  auto byteSize = tensor->getDataByteSize();
+  void *data = std::malloc(byteSize);
+  std::memcpy(data, tensor->getData(), byteSize);
+  auto *descriptor = MemrefDescriptor::create(tensor->getExtents(), data);
+  return UnrankedMemref{tensor->getRank(), descriptor};
+}
+
+static Tensor *convertUnrankedMemrefToRefbackrtTensor(
+    std::int64_t rank, MemrefDescriptor *descriptor, ElementType elementType) {
+  // Launder from std::int64_t to std::int32_t.
+  auto extents64 = descriptor->getSizes(rank);
+  constexpr int kMaxRank = 20;
+  std::array<std::int32_t, kMaxRank> extents32Buf;
+  for (int i = 0, e = extents64.size(); i < e; i++)
+    extents32Buf[i] = extents64[i];
+  return Tensor::createRaw(ArrayRef<std::int32_t>(extents32Buf.data(), rank),
+                           elementType, descriptor->dataPtr);
+}
+
 //===----------------------------------------------------------------------===//
 // Tensor
 //===----------------------------------------------------------------------===//
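Since the tail-allocated layout above is easy to get wrong, here is an illustrative walk-through (not part of the patch, and it assumes the MemrefDescriptor definition from the hunk above) of what create produces for extents {2, 3}: the three header fields are followed in memory by sizes [2, 3] and then row-major strides [3, 1]:

    #include <cassert>
    #include <cstdlib>

    void layoutExample(void *data) {
      std::int32_t extents[] = {2, 3};
      MemrefDescriptor *d = MemrefDescriptor::create(extents, data);
      assert(d->getSizes(2)[0] == 2 && d->getSizes(2)[1] == 3);
      // Dense row-major strides: innermost dimension has stride 1.
      assert(d->getStrides(2)[0] == 3 && d->getStrides(2)[1] == 1);
      // For this layout, size[0] * stride[0] is the total element count.
      assert(d->getNumElements(2) == 6);
      std::free(d);
    }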
@@ -93,25 +178,71 @@ void refbackrt::invoke(ModuleDescriptor *moduleDescriptor,
   assert(descriptor && "unknown function name");
   assert(inputs.size() < kMaxArity && "number of inputs exceeds kMaxArity");
   assert(outputs.size() < kMaxArity && "number of outputs exceeds kMaxArity");
-  std::array<Tensor *, kMaxArity> inputTensorPtrs;
-  std::array<Tensor *, kMaxArity> outputTensorPtrs;
-  std::array<void *, kMaxArity> packedInputs;
+  std::array<UnrankedMemref, kMaxArity> inputUnrankedMemrefs;
+  std::array<UnrankedMemref, kMaxArity> outputUnrankedMemrefs;
+  std::array<void *, kMaxArity * 2> packedInputs;
   std::array<void *, kMaxArity> packedOutputs;
-  for (int i = 0, e = inputs.size(); i < e; i++)
-    inputTensorPtrs[i] = inputs[i].get();
-  for (int i = 0, e = inputs.size(); i < e; i++)
-    packedInputs[i] = ToVoidPtr(inputTensorPtrs[i]);
+  for (int i = 0, e = inputs.size(); i < e; i++) {
+    inputUnrankedMemrefs[i] =
+        convertRefbackrtTensorToUnrankedMemref(inputs[i].get());
+  }
+  for (int i = 0, e = inputs.size(); i < e; i++) {
+    packedInputs[2 * i] = ToVoidPtr(&inputUnrankedMemrefs[i].rank);
+    packedInputs[2 * i + 1] = ToVoidPtr(&inputUnrankedMemrefs[i].descriptor);
+  }
+  for (int i = 0, e = outputs.size(); i < e; i++) {
+    packedOutputs[i] = ToVoidPtr(&outputUnrankedMemrefs[i]);
+  }
   descriptor->functionPtr(packedInputs.data(), packedOutputs.data());
-  for (int i = 0, e = outputs.size(); i < e; i++)
-    outputTensorPtrs[i] = static_cast<Tensor *>(packedOutputs[i]);
-  // TODO: Actually manage refcounts inside the compiler.
-  // Right now, we only pass around refbackrt.tensor's in trivial ways on ABI
-  // boundaries, so the following contract of the compiler-generated code works:
-  // - input tensors are never retained or released
-  // - output tensors always have refcount 0. Hence the next line here is
-  //   actually essential because it increments the refcounts so they are nonzero.
-  for (int i = 0, e = outputs.size(); i < e; i++)
-    outputs[i] = Ref<Tensor>(outputTensorPtrs[i]);
+  for (int i = 0, e = outputs.size(); i < e; i++) {
+    // TODO: Have the compiler emit the element type in the metadata.
+    auto elementType = ElementType::F32;
+    Tensor *tensor = convertUnrankedMemrefToRefbackrtTensor(
+        outputUnrankedMemrefs[i].rank, outputUnrankedMemrefs[i].descriptor,
+        elementType);
+    outputs[i] = Ref<Tensor>(tensor);
+  }
+  for (int i = 0, e = outputs.size(); i < e; i++) {
+    void *allocatedPtr = outputUnrankedMemrefs[i].descriptor->allocatedPtr;
+    // Multiple returned memrefs can point into the same underlying
+    // malloc allocation. Do a linear scan to see if any of the previously
+    // deallocated buffers already freed this pointer.
+    bool bufferNeedsFreeing = true;
+    for (int j = 0; j < i; j++) {
+      if (allocatedPtr == outputUnrankedMemrefs[j].descriptor->allocatedPtr)
+        bufferNeedsFreeing = false;
+    }
+    if (bufferNeedsFreeing)
+      std::free(allocatedPtr);
+  }
+  for (int i = 0, e = inputs.size(); i < e; i++) {
+    void *allocatedPtr = inputUnrankedMemrefs[i].descriptor->allocatedPtr;
+    bool bufferNeedsFreeing = true;
+    for (int j = 0, je = outputs.size(); j < je; j++) {
+      if (allocatedPtr == outputUnrankedMemrefs[j].descriptor->allocatedPtr)
+        bufferNeedsFreeing = false;
+    }
+    // HACK: The returned memref can point into statically allocated memory that
+    // we can't pass to `free`, such as the result of lowering a tensor-valued
+    // `std.constant` to `std.global_memref`. The LLVM lowering of
+    // std.global_memref sets the allocated pointer to the magic value
+    // 0xDEADBEEF, which we sniff for here. This is yet another strong signal
+    // that memref is really not the right abstraction for ABIs.
+    if (reinterpret_cast<std::intptr_t>(allocatedPtr) == 0xDEADBEEF)
+      bufferNeedsFreeing = false;
+    if (bufferNeedsFreeing)
+      std::free(allocatedPtr);
+  }
+
+  for (int i = 0, e = outputs.size(); i < e; i++) {
+    // The LLVM lowering guarantees that each returned unranked memref
+    // descriptor is separately malloc'ed, so no need to do anything special
+    // like we had to do for the allocatedPtrs.
+    std::free(outputUnrankedMemrefs[i].descriptor);
+  }
+  for (int i = 0, e = inputs.size(); i < e; i++) {
+    std::free(inputUnrankedMemrefs[i].descriptor);
+  }
 }
 
 LogicalResult refbackrt::getMetadata(ModuleDescriptor *moduleDescriptor,
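From the user's side nothing changes: invoke still traffics in refcounted Tensors. A hypothetical call site, with the invoke signature and function name assumed from the surrounding headers rather than shown in this patch:

    void runIdentity(refbackrt::ModuleDescriptor *module,
                     refbackrt::Ref<refbackrt::Tensor> input) {
      refbackrt::Ref<refbackrt::Tensor> results[1];
      // Internally: copies `input` into a malloc'ed unranked memref, packs
      // (&rank, &descriptor) pairs into the void** inputs array, calls the
      // compiled wrapper, rewraps the returned descriptors as Tensors, and
      // frees each distinct compiler-allocated buffer exactly once.
      refbackrt::invoke(module, "identity", {input}, results);
    }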
diff --git a/test/Dialect/Refbackrt/ops.mlir b/test/Dialect/Refbackrt/ops.mlir
index bf67058ca..f1e6b5925 100644
--- a/test/Dialect/Refbackrt/ops.mlir
+++ b/test/Dialect/Refbackrt/ops.mlir
@@ -6,8 +6,6 @@ refbackrt.module_metadata {
   refbackrt.func_metadata {funcName = @f, numInputs = 1 : i32, numOutputs = 0 : i32}
 }
 
-// CHECK-LABEL: func @f
-// CHECK-SAME: !refbackrt.tensor
-func @f(%arg0: !refbackrt.tensor) {
+func @f(%arg0: memref<*xf32>) {
   return
 }
diff --git a/test/RefBackend/lower-to-llvm.mlir b/test/RefBackend/lower-to-llvm.mlir
index cdde32cec..8f70fda0b 100644
--- a/test/RefBackend/lower-to-llvm.mlir
+++ b/test/RefBackend/lower-to-llvm.mlir
@@ -1,118 +1,75 @@
 // RUN: npcomp-opt -refback-lower-to-llvm -split-input-file <%s | FileCheck %s --dump-input=fail
 
-// CHECK-LABEL: llvm.func @__refbackrt_wrapper_identity(
-// CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr<ptr<i8>>,
-// CHECK-SAME: %[[VAL_1:.*]]: !llvm.ptr<ptr<i8>>) {
-// CHECK: %[[VAL_2:.*]] = llvm.mlir.constant(0 : i32) : !llvm.i32
-// CHECK: %[[VAL_3:.*]] = llvm.getelementptr %[[VAL_0]]{{\[}}%[[VAL_2]]] : (!llvm.ptr<ptr<i8>>, !llvm.i32) -> !llvm.ptr<i8>
-// CHECK: %[[VAL_4:.*]] = llvm.bitcast %[[VAL_3]] : !llvm.ptr<i8> to !llvm.ptr<ptr<i8>>
-// CHECK: %[[VAL_5:.*]] = llvm.load %[[VAL_4]] : !llvm.ptr<ptr<i8>>
-// CHECK: %[[VAL_6:.*]] = llvm.call @identity(%[[VAL_5]]) : (!llvm.ptr<i8>) -> !llvm.ptr<i8>
-// CHECK: %[[VAL_7:.*]] = llvm.mlir.constant(0 : i32) : !llvm.i32
-// CHECK: %[[VAL_8:.*]] = llvm.getelementptr %[[VAL_1]]{{\[}}%[[VAL_7]]] : (!llvm.ptr<ptr<i8>>, !llvm.i32) -> !llvm.ptr<i8>
-// CHECK: %[[VAL_9:.*]] = llvm.bitcast %[[VAL_8]] : !llvm.ptr<i8> to !llvm.ptr<ptr<i8>>
-// CHECK: llvm.store %[[VAL_6]], %[[VAL_9]] : !llvm.ptr<ptr<i8>>
-// CHECK: llvm.return
-// CHECK: }
-// CHECK: llvm.func @__npcomp_compiler_rt_abort_if(!llvm.i1, !llvm.ptr<i8>)
-// CHECK: llvm.func @__npcomp_compiler_rt_to_memref(!llvm.ptr<i8>) -> !llvm.struct<(i64, ptr<i8>)>
-// CHECK: llvm.func @__npcomp_compiler_rt_from_memref(!llvm.i64, !llvm.ptr<i8>) -> !llvm.ptr<i8>
-// CHECK: llvm.mlir.global internal constant @__npcomp_internal_constant_identity("identity")
-
-// CHECK-LABEL: llvm.mlir.global internal constant @__npcomp_func_descriptors() : !llvm.array<1 x struct<(i32, ptr<i8>, ptr<i8>, i32, i32)>> {
-// CHECK: %[[VAL_0:.*]] = llvm.mlir.undef : !llvm.array<1 x struct<(i32, ptr<i8>, ptr<i8>, i32, i32)>>
-// CHECK: %[[VAL_1:.*]] = llvm.mlir.constant(0 : i32) : !llvm.i32
-// CHECK: %[[VAL_2:.*]] = llvm.mlir.constant(8 : i32) : !llvm.i32
-// CHECK: %[[VAL_3:.*]] = llvm.insertvalue %[[VAL_2]], %[[VAL_0]][0 : i32, 0 : i32] : !llvm.array<1 x struct<(i32, ptr<i8>, ptr<i8>, i32, i32)>>
-// CHECK: %[[VAL_4:.*]] = llvm.mlir.addressof @__npcomp_internal_constant_identity : !llvm.ptr<array<8 x i8>>
-// CHECK: %[[VAL_5:.*]] = llvm.getelementptr %[[VAL_4]]{{\[}}%[[VAL_1]], %[[VAL_1]]] : (!llvm.ptr<array<8 x i8>>, !llvm.i32, !llvm.i32) -> !llvm.ptr<i8>
-// CHECK: %[[VAL_6:.*]] = llvm.insertvalue %[[VAL_5]], %[[VAL_3]][0 : i32, 1 : i32] : !llvm.array<1 x struct<(i32, ptr<i8>, ptr<i8>, i32, i32)>>
-// CHECK: %[[VAL_7:.*]] = llvm.mlir.addressof @__refbackrt_wrapper_identity : !llvm.ptr<func<void (ptr<ptr<i8>>, ptr<ptr<i8>>)>>
-// CHECK: %[[VAL_8:.*]] = llvm.bitcast %[[VAL_7]] : !llvm.ptr<func<void (ptr<ptr<i8>>, ptr<ptr<i8>>)>> to !llvm.ptr<i8>
-// CHECK: %[[VAL_9:.*]] = llvm.insertvalue %[[VAL_8]], %[[VAL_6]][0 : i32, 2 : i32] : !llvm.array<1 x struct<(i32, ptr<i8>, ptr<i8>, i32, i32)>>
-// CHECK: %[[VAL_10:.*]] = llvm.mlir.constant(1 : i32) : !llvm.i32
-// CHECK: %[[VAL_11:.*]] = llvm.insertvalue %[[VAL_10]], %[[VAL_9]][0 : i32, 3 : i32] : !llvm.array<1 x struct<(i32, ptr<i8>, ptr<i8>, i32, i32)>>
-// CHECK: %[[VAL_12:.*]] = llvm.mlir.constant(1 : i32) : !llvm.i32
-// CHECK: %[[VAL_13:.*]] = llvm.insertvalue %[[VAL_12]], %[[VAL_11]][0 : i32, 4 : i32] : !llvm.array<1 x struct<(i32, ptr<i8>, ptr<i8>, i32, i32)>>
-// CHECK: llvm.return %[[VAL_13]] : !llvm.array<1 x struct<(i32, ptr<i8>, ptr<i8>, i32, i32)>>
-// CHECK: }
-
-// CHECK-LABEL: llvm.mlir.global external constant @_mlir___npcomp_module_descriptor() : !llvm.struct<(i32, ptr<struct<(i32, ptr<i8>, ptr<i8>, i32, i32)>>)> {
-// CHECK: %[[VAL_0:.*]] = llvm.mlir.undef : !llvm.struct<(i32, ptr<struct<(i32, ptr<i8>, ptr<i8>, i32, i32)>>)>
-// CHECK: %[[VAL_1:.*]] = llvm.mlir.constant(1 : i32) : !llvm.i32
-// CHECK: %[[VAL_2:.*]] = llvm.insertvalue %[[VAL_1]], %[[VAL_0]][0 : i32] : !llvm.struct<(i32, ptr<struct<(i32, ptr<i8>, ptr<i8>, i32, i32)>>)>
-// CHECK: %[[VAL_3:.*]] = llvm.mlir.addressof @__npcomp_func_descriptors : !llvm.ptr<array<1 x struct<(i32, ptr<i8>, ptr<i8>, i32, i32)>>>
-// CHECK: %[[VAL_4:.*]] = llvm.bitcast %[[VAL_3]] : !llvm.ptr<array<1 x struct<(i32, ptr<i8>, ptr<i8>, i32, i32)>>> to !llvm.ptr<struct<(i32, ptr<i8>, ptr<i8>, i32, i32)>>
-// CHECK: %[[VAL_5:.*]] = llvm.insertvalue %[[VAL_4]], %[[VAL_2]][1 : i32] : !llvm.struct<(i32, ptr<struct<(i32, ptr<i8>, ptr<i8>, i32, i32)>>)>
-// CHECK: llvm.return %[[VAL_5]] : !llvm.struct<(i32, ptr<struct<(i32, ptr<i8>, ptr<i8>, i32, i32)>>)>
-// CHECK: }
-
-refbackrt.module_metadata {
-  refbackrt.func_metadata {funcName = @identity, numInputs = 1 : i32, numOutputs = 1 : i32}
-}
-
-
-// CHECK-LABEL: llvm.func @identity(
-// CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr<i8>) -> !llvm.ptr<i8> {
-// CHECK: llvm.return %[[VAL_0]] : !llvm.ptr<i8>
-// CHECK: }
-func @identity(%arg0: !refbackrt.tensor) -> !refbackrt.tensor {
-  return %arg0 : !refbackrt.tensor
-}
-
-// -----
-
 // Test input/output arg marshaling.
 
 // CHECK-LABEL: llvm.func @__refbackrt_wrapper_inputs1results2(
-// CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr<ptr<i8>>,
-// CHECK-SAME: %[[VAL_1:.*]]: !llvm.ptr<ptr<i8>>) {
+// CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr<ptr<i8>>,
+// CHECK-SAME: %[[VAL_1:.*]]: !llvm.ptr<ptr<i8>>) {
 // CHECK: %[[VAL_2:.*]] = llvm.mlir.constant(0 : i32) : !llvm.i32
-// CHECK: %[[VAL_3:.*]] = llvm.getelementptr %[[VAL_0]]{{\[}}%[[VAL_2]]] : (!llvm.ptr<ptr<i8>>, !llvm.i32) -> !llvm.ptr<i8>
-// CHECK: %[[VAL_4:.*]] = llvm.bitcast %[[VAL_3]] : !llvm.ptr<i8> to !llvm.ptr<ptr<i8>>
-// CHECK: %[[VAL_5:.*]] = llvm.load %[[VAL_4]] : !llvm.ptr<ptr<i8>>
-// CHECK: %[[VAL_6:.*]] = llvm.call @inputs1results2(%[[VAL_5]]) : (!llvm.ptr<i8>) -> !llvm.struct<(ptr<i8>, ptr<i8>)>
-// CHECK: %[[VAL_7:.*]] = llvm.mlir.constant(0 : i32) : !llvm.i32
-// CHECK: %[[VAL_8:.*]] = llvm.getelementptr %[[VAL_1]]{{\[}}%[[VAL_7]]] : (!llvm.ptr<ptr<i8>>, !llvm.i32) -> !llvm.ptr<i8>
-// CHECK: %[[VAL_9:.*]] = llvm.bitcast %[[VAL_8]] : !llvm.ptr<i8> to !llvm.ptr<ptr<i8>>
-// CHECK: %[[VAL_10:.*]] = llvm.extractvalue %[[VAL_6]][0 : i32] : !llvm.struct<(ptr<i8>, ptr<i8>)>
-// CHECK: llvm.store %[[VAL_10]], %[[VAL_9]] : !llvm.ptr<ptr<i8>>
-// CHECK: %[[VAL_11:.*]] = llvm.mlir.constant(1 : i32) : !llvm.i32
-// CHECK: %[[VAL_12:.*]] = llvm.getelementptr %[[VAL_1]]{{\[}}%[[VAL_11]]] : (!llvm.ptr<ptr<i8>>, !llvm.i32) -> !llvm.ptr<i8>
-// CHECK: %[[VAL_13:.*]] = llvm.bitcast %[[VAL_12]] : !llvm.ptr<i8> to !llvm.ptr<ptr<i8>>
-// CHECK: %[[VAL_14:.*]] = llvm.extractvalue %[[VAL_6]][1 : i32] : !llvm.struct<(ptr<i8>, ptr<i8>)>
-// CHECK: llvm.store %[[VAL_14]], %[[VAL_13]] : !llvm.ptr<ptr<i8>>
+// CHECK: %[[VAL_3:.*]] = llvm.getelementptr %[[VAL_0]]{{\[}}%[[VAL_2]]] : (!llvm.ptr<ptr<i8>>, !llvm.i32) -> !llvm.ptr<ptr<i8>>
+// CHECK: %[[VAL_4:.*]] = llvm.load %[[VAL_3]] : !llvm.ptr<ptr<i8>>
+// CHECK: %[[VAL_5:.*]] = llvm.bitcast %[[VAL_4]] : !llvm.ptr<i8> to !llvm.ptr<i64>
+// CHECK: %[[VAL_6:.*]] = llvm.load %[[VAL_5]] : !llvm.ptr<i64>
+// CHECK: %[[VAL_7:.*]] = llvm.mlir.constant(1 : i32) : !llvm.i32
+// CHECK: %[[VAL_8:.*]] = llvm.getelementptr %[[VAL_0]]{{\[}}%[[VAL_7]]] : (!llvm.ptr<ptr<i8>>, !llvm.i32) -> !llvm.ptr<ptr<i8>>
+// CHECK: %[[VAL_9:.*]] = llvm.load %[[VAL_8]] : !llvm.ptr<ptr<i8>>
+// CHECK: %[[VAL_10:.*]] = llvm.bitcast %[[VAL_9]] : !llvm.ptr<i8> to !llvm.ptr<ptr<i8>>
+// CHECK: %[[VAL_11:.*]] = llvm.load %[[VAL_10]] : !llvm.ptr<ptr<i8>>
+// CHECK: %[[VAL_12:.*]] = llvm.call @inputs1results2(%[[VAL_6]], %[[VAL_11]]) : (!llvm.i64, !llvm.ptr<i8>) -> !llvm.struct<(struct<(i64, ptr<i8>)>, struct<(i64, ptr<i8>)>)>
+// CHECK: %[[VAL_13:.*]] = llvm.mlir.constant(0 : i32) : !llvm.i32
+// CHECK: %[[VAL_14:.*]] = llvm.getelementptr %[[VAL_1]]{{\[}}%[[VAL_13]]] : (!llvm.ptr<ptr<i8>>, !llvm.i32) -> !llvm.ptr<ptr<i8>>
+// CHECK: %[[VAL_15:.*]] = llvm.load %[[VAL_14]] : !llvm.ptr<ptr<i8>>
+// CHECK: %[[VAL_16:.*]] = llvm.bitcast %[[VAL_15]] : !llvm.ptr<i8> to !llvm.ptr<struct<(i64, ptr<i8>)>>
+// CHECK: %[[VAL_17:.*]] = llvm.extractvalue %[[VAL_12]][0 : i32] : !llvm.struct<(struct<(i64, ptr<i8>)>, struct<(i64, ptr<i8>)>)>
+// CHECK: llvm.store %[[VAL_17]], %[[VAL_16]] : !llvm.ptr<struct<(i64, ptr<i8>)>>
+// CHECK: %[[VAL_18:.*]] = llvm.mlir.constant(1 : i32) : !llvm.i32
+// CHECK: %[[VAL_19:.*]] = llvm.getelementptr %[[VAL_1]]{{\[}}%[[VAL_18]]] : (!llvm.ptr<ptr<i8>>, !llvm.i32) -> !llvm.ptr<ptr<i8>>
+// CHECK: %[[VAL_20:.*]] = llvm.load %[[VAL_19]] : !llvm.ptr<ptr<i8>>
+// CHECK: %[[VAL_21:.*]] = llvm.bitcast %[[VAL_20]] : !llvm.ptr<i8> to !llvm.ptr<struct<(i64, ptr<i8>)>>
+// CHECK: %[[VAL_22:.*]] = llvm.extractvalue %[[VAL_12]][1 : i32] : !llvm.struct<(struct<(i64, ptr<i8>)>, struct<(i64, ptr<i8>)>)>
+// CHECK: llvm.store %[[VAL_22]], %[[VAL_21]] : !llvm.ptr<struct<(i64, ptr<i8>)>>
 // CHECK: llvm.return
 // CHECK: }
 
 // CHECK-LABEL: llvm.func @__refbackrt_wrapper_inputs1results1(
-// CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr<ptr<i8>>,
-// CHECK-SAME: %[[VAL_1:.*]]: !llvm.ptr<ptr<i8>>) {
+// CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr<ptr<i8>>,
+// CHECK-SAME: %[[VAL_1:.*]]: !llvm.ptr<ptr<i8>>) {
 // CHECK: %[[VAL_2:.*]] = llvm.mlir.constant(0 : i32) : !llvm.i32
-// CHECK: %[[VAL_3:.*]] = llvm.getelementptr %[[VAL_0]]{{\[}}%[[VAL_2]]] : (!llvm.ptr<ptr<i8>>, !llvm.i32) -> !llvm.ptr<i8>
-// CHECK: %[[VAL_4:.*]] = llvm.bitcast %[[VAL_3]] : !llvm.ptr<i8> to !llvm.ptr<ptr<i8>>
-// CHECK: %[[VAL_5:.*]] = llvm.load %[[VAL_4]] : !llvm.ptr<ptr<i8>>
-// CHECK: %[[VAL_6:.*]] = llvm.call @inputs1results1(%[[VAL_5]]) : (!llvm.ptr<i8>) -> !llvm.ptr<i8>
-// CHECK: %[[VAL_7:.*]] = llvm.mlir.constant(0 : i32) : !llvm.i32
-// CHECK: %[[VAL_8:.*]] = llvm.getelementptr %[[VAL_1]]{{\[}}%[[VAL_7]]] : (!llvm.ptr<ptr<i8>>, !llvm.i32) -> !llvm.ptr<i8>
-// CHECK: %[[VAL_9:.*]] = llvm.bitcast %[[VAL_8]] : !llvm.ptr<i8> to !llvm.ptr<ptr<i8>>
-// CHECK: llvm.store %[[VAL_6]], %[[VAL_9]] : !llvm.ptr<ptr<i8>>
+// CHECK: %[[VAL_3:.*]] = llvm.getelementptr %[[VAL_0]]{{\[}}%[[VAL_2]]] : (!llvm.ptr<ptr<i8>>, !llvm.i32) -> !llvm.ptr<ptr<i8>>
+// CHECK: %[[VAL_4:.*]] = llvm.load %[[VAL_3]] : !llvm.ptr<ptr<i8>>
+// CHECK: %[[VAL_5:.*]] = llvm.bitcast %[[VAL_4]] : !llvm.ptr<i8> to !llvm.ptr<i64>
+// CHECK: %[[VAL_6:.*]] = llvm.load %[[VAL_5]] : !llvm.ptr<i64>
+// CHECK: %[[VAL_7:.*]] = llvm.mlir.constant(1 : i32) : !llvm.i32
+// CHECK: %[[VAL_8:.*]] = llvm.getelementptr %[[VAL_0]]{{\[}}%[[VAL_7]]] : (!llvm.ptr<ptr<i8>>, !llvm.i32) -> !llvm.ptr<ptr<i8>>
+// CHECK: %[[VAL_9:.*]] = llvm.load %[[VAL_8]] : !llvm.ptr<ptr<i8>>
+// CHECK: %[[VAL_10:.*]] = llvm.bitcast %[[VAL_9]] : !llvm.ptr<i8> to !llvm.ptr<ptr<i8>>
+// CHECK: %[[VAL_11:.*]] = llvm.load %[[VAL_10]] : !llvm.ptr<ptr<i8>>
+// CHECK: %[[VAL_12:.*]] = llvm.call @inputs1results1(%[[VAL_6]], %[[VAL_11]]) : (!llvm.i64, !llvm.ptr<i8>) -> !llvm.struct<(i64, ptr<i8>)>
+// CHECK: %[[VAL_13:.*]] = llvm.mlir.constant(0 : i32) : !llvm.i32
+// CHECK: %[[VAL_14:.*]] = llvm.getelementptr %[[VAL_1]]{{\[}}%[[VAL_13]]] : (!llvm.ptr<ptr<i8>>, !llvm.i32) -> !llvm.ptr<ptr<i8>>
+// CHECK: %[[VAL_15:.*]] = llvm.load %[[VAL_14]] : !llvm.ptr<ptr<i8>>
+// CHECK: %[[VAL_16:.*]] = llvm.bitcast %[[VAL_15]] : !llvm.ptr<i8> to !llvm.ptr<struct<(i64, ptr<i8>)>>
+// CHECK: llvm.store %[[VAL_12]], %[[VAL_16]] : !llvm.ptr<struct<(i64, ptr<i8>)>>
 // CHECK: llvm.return
 // CHECK: }
 
-// CHECK-LABEL: llvm.func @__refbackrt_wrapper_inputs1results0(
-// CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr<ptr<i8>>,
-// CHECK-SAME: %[[VAL_1:.*]]: !llvm.ptr<ptr<i8>>) {
+// CHECK-LABEL: llvm.func @__refbackrt_wrapper_inputs1results0(
+// CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr<ptr<i8>>,
+// CHECK-SAME: %[[VAL_1:.*]]: !llvm.ptr<ptr<i8>>) {
 // CHECK: %[[VAL_2:.*]] = llvm.mlir.constant(0 : i32) : !llvm.i32
-// CHECK: %[[VAL_3:.*]] = llvm.getelementptr %[[VAL_0]]{{\[}}%[[VAL_2]]] : (!llvm.ptr<ptr<i8>>, !llvm.i32) -> !llvm.ptr<i8>
-// CHECK: %[[VAL_4:.*]] = llvm.bitcast %[[VAL_3]] : !llvm.ptr<i8> to !llvm.ptr<ptr<i8>>
-// CHECK: %[[VAL_5:.*]] = llvm.load %[[VAL_4]] : !llvm.ptr<ptr<i8>>
-// CHECK: llvm.call @inputs1results0(%[[VAL_5]]) : (!llvm.ptr<i8>) -> ()
+// CHECK: %[[VAL_3:.*]] = llvm.getelementptr %[[VAL_0]]{{\[}}%[[VAL_2]]] : (!llvm.ptr<ptr<i8>>, !llvm.i32) -> !llvm.ptr<ptr<i8>>
+// CHECK: %[[VAL_4:.*]] = llvm.load %[[VAL_3]] : !llvm.ptr<ptr<i8>>
+// CHECK: %[[VAL_5:.*]] = llvm.bitcast %[[VAL_4]] : !llvm.ptr<i8> to !llvm.ptr<i64>
+// CHECK: %[[VAL_6:.*]] = llvm.load %[[VAL_5]] : !llvm.ptr<i64>
+// CHECK: %[[VAL_7:.*]] = llvm.mlir.constant(1 : i32) : !llvm.i32
+// CHECK: %[[VAL_8:.*]] = llvm.getelementptr %[[VAL_0]]{{\[}}%[[VAL_7]]] : (!llvm.ptr<ptr<i8>>, !llvm.i32) -> !llvm.ptr<ptr<i8>>
+// CHECK: %[[VAL_9:.*]] = llvm.load %[[VAL_8]] : !llvm.ptr<ptr<i8>>
+// CHECK: %[[VAL_10:.*]] = llvm.bitcast %[[VAL_9]] : !llvm.ptr<i8> to !llvm.ptr<ptr<i8>>
+// CHECK: %[[VAL_11:.*]] = llvm.load %[[VAL_10]] : !llvm.ptr<ptr<i8>>
+// CHECK: llvm.call @inputs1results0(%[[VAL_6]], %[[VAL_11]]) : (!llvm.i64, !llvm.ptr<i8>) -> ()
 // CHECK: llvm.return
 // CHECK: }
 
 // CHECK: llvm.func @__npcomp_compiler_rt_abort_if(!llvm.i1, !llvm.ptr<i8>)
-// CHECK: llvm.func @__npcomp_compiler_rt_to_memref(!llvm.ptr<i8>) -> !llvm.struct<(i64, ptr<i8>)>
-// CHECK: llvm.func @__npcomp_compiler_rt_from_memref(!llvm.i64, !llvm.ptr<i8>) -> !llvm.ptr<i8>
 // CHECK: llvm.mlir.global internal constant @__npcomp_internal_constant_inputs1results0("inputs1results0")
 // CHECK: llvm.mlir.global internal constant @__npcomp_internal_constant_inputs1results1("inputs1results1")
 // CHECK: llvm.mlir.global internal constant @__npcomp_internal_constant_inputs1results2("inputs1results2")
@@ -175,44 +132,24 @@
 refbackrt.module_metadata {
   refbackrt.func_metadata {funcName = @inputs1results2, numInputs = 1 : i32, numOutputs = 2 : i32}
 }
 
-
-
-// CHECK-LABEL: llvm.func @inputs1results0(
-// CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr<i8>) {
-// CHECK: llvm.return
-// CHECK: }
-func @inputs1results0(%arg0: !refbackrt.tensor) {
+func @inputs1results0(%arg0: memref<*xf32>) {
   return
 }
 
-// CHECK-LABEL: llvm.func @inputs1results1(
-// CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr<i8>) -> !llvm.ptr<i8> {
-// CHECK: llvm.return %[[VAL_0]] : !llvm.ptr<i8>
-// CHECK: }
-func @inputs1results1(%arg0: !refbackrt.tensor) -> !refbackrt.tensor {
-  return %arg0 : !refbackrt.tensor
+func @inputs1results1(%arg0: memref<*xf32>) -> memref<*xf32> {
+  return %arg0 : memref<*xf32>
 }
 
-// CHECK-LABEL: llvm.func @inputs1results2(
-// CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr<i8>) -> !llvm.struct<(ptr<i8>, ptr<i8>)> {
-// CHECK: %[[VAL_1:.*]] = llvm.mlir.undef : !llvm.struct<(ptr<i8>, ptr<i8>)>
-// CHECK: %[[VAL_2:.*]] = llvm.insertvalue %[[VAL_0]], %[[VAL_1]][0] : !llvm.struct<(ptr<i8>, ptr<i8>)>
-// CHECK: %[[VAL_3:.*]] = llvm.insertvalue %[[VAL_0]], %[[VAL_2]][1] : !llvm.struct<(ptr<i8>, ptr<i8>)>
-// CHECK: llvm.return %[[VAL_3]] : !llvm.struct<(ptr<i8>, ptr<i8>)>
-// CHECK: }
-func @inputs1results2(%arg0: !refbackrt.tensor) -> (!refbackrt.tensor, !refbackrt.tensor) {
-  return %arg0, %arg0 : !refbackrt.tensor, !refbackrt.tensor
+func @inputs1results2(%arg0: memref<*xf32>) -> (memref<*xf32>, memref<*xf32>) {
+  return %arg0, %arg0 : memref<*xf32>, memref<*xf32>
 }
 
-
 // -----
 
 // Test emission of compiler runtime functions.
 
 // CHECK: llvm.mlir.global internal constant @[[STRSYM:.*]]("msg\00")
 // CHECK: llvm.func @__npcomp_compiler_rt_abort_if(!llvm.i1, !llvm.ptr<i8>)
-// CHECK: llvm.func @__npcomp_compiler_rt_to_memref(!llvm.ptr<i8>) -> !llvm.struct<(i64, ptr<i8>)>
-// CHECK: llvm.func @__npcomp_compiler_rt_from_memref(!llvm.i64, !llvm.ptr<i8>) -> !llvm.ptr<i8>
 
 // CHECK-LABEL: llvm.func @calls_abort_if(
 // CHECK-SAME: %[[VAL_0:.*]]: !llvm.i1) {
@@ -226,29 +163,3 @@
 func @calls_abort_if(%arg0: i1) {
   refbackrt.abort_if %arg0, "msg"
   return
 }
-
-// CHECK-LABEL: llvm.func @calls_to_memref(
-// CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr<i8>) {
-// CHECK: %[[VAL_1:.*]] = llvm.call @__npcomp_compiler_rt_to_memref(%[[VAL_0]]) : (!llvm.ptr<i8>) -> !llvm.struct<(i64, ptr<i8>)>
-// CHECK: llvm.return
-// CHECK: }
-func @calls_to_memref(%arg0: !refbackrt.tensor) {
-  %0 = refbackrt.to_memref %arg0 : memref<*xf32>
-  return
-}
-
-// CHECK-LABEL: llvm.func @calls_from_memref(
-// CHECK-SAME: %[[VAL_0:.*]]: !llvm.i64,
-// CHECK-SAME: %[[VAL_1:.*]]: !llvm.ptr<i8>) -> !llvm.ptr<i8> {
-// CHECK: %[[VAL_2:.*]] = llvm.mlir.undef : !llvm.struct<(i64, ptr<i8>)>
-// CHECK: %[[VAL_3:.*]] = llvm.insertvalue %[[VAL_0]], %[[VAL_2]][0] : !llvm.struct<(i64, ptr<i8>)>
-// CHECK: %[[VAL_4:.*]] = llvm.insertvalue %[[VAL_1]], %[[VAL_3]][1] : !llvm.struct<(i64, ptr<i8>)>
-// CHECK: %[[VAL_5:.*]] = llvm.extractvalue %[[VAL_4]][0 : i32] : !llvm.struct<(i64, ptr<i8>)>
-// CHECK: %[[VAL_6:.*]] = llvm.extractvalue %[[VAL_4]][1 : i32] : !llvm.struct<(i64, ptr<i8>)>
-// CHECK: %[[VAL_7:.*]] = llvm.call @__npcomp_compiler_rt_from_memref(%[[VAL_5]], %[[VAL_6]]) : (!llvm.i64, !llvm.ptr<i8>) -> !llvm.ptr<i8>
-// CHECK: llvm.return %[[VAL_7]] : !llvm.ptr<i8>
-// CHECK: }
-func @calls_from_memref(%arg0: memref<*xf32>) -> !refbackrt.tensor {
-  %0 = refbackrt.from_memref %arg0 : memref<*xf32>
-  return %0 : !refbackrt.tensor
-}
diff --git a/test/RefBackend/lower-to-refbackrt-abi.mlir b/test/RefBackend/lower-to-refbackrt-abi.mlir
index 81071da7e..ee6c36a4e 100644
--- a/test/RefBackend/lower-to-refbackrt-abi.mlir
+++ b/test/RefBackend/lower-to-refbackrt-abi.mlir
@@ -20,34 +20,18 @@ func @f_1input_2outputs(%arg0: memref<?xf32>) -> (memref<?xf32>, memref<?xf32>)
 
 // Test ABI conversions.
 
-// CHECK-LABEL: func @identity(%arg0: !refbackrt.tensor) -> !refbackrt.tensor
+// CHECK-LABEL: func @identity(%arg0: memref<*xf32>) -> memref<*xf32>
 func @identity(%arg0: memref<?xf32>) -> memref<?xf32> {
-  // The argument materialization.
-  // In this test case, these go unused since, as described below, the new
-  // argument value is seen immediately by the return op for some reason.
-  // CHECK-NEXT: %[[INABIMEMREF:.*]] = refbackrt.to_memref %arg0 : memref<*xf32>
-  // CHECK-NEXT: %[[MEMREF:.*]] = memref_cast %[[INABIMEMREF]] : memref<*xf32> to memref<?xf32>
-
-  // TODO: Why do these target materializations not happen in this particular
-  // test?
-  // Somehow, the return op rewrite sees the new argument value immediately,
-  // rather than the result of replaceUsesOfBlockArgument from
-  // FuncOpSignatureConversion
-  // Cxxxx-NEXT: %[[OUTABIMEMREF:.*]] = memref_cast %[[MEMREF]] : memref<?xf32> to memref<*xf32>
-  // Cxxxx-NEXT: %[[RET:.*]] = refbackrt.from_memref %[[OUTABIMEMREF]] : memref<*xf32>
-  // Cxxxx-NEXT: return %[[RET]]
-
-  // CHECK-NEXT: return %arg0
+  // CHECK: return %arg0
   return %arg0 : memref<?xf32>
 }
 
 // -----
 
-// CHECK-LABEL: func @use_of_arg(%arg0: !refbackrt.tensor)
+// CHECK-LABEL: func @use_of_arg(%arg0: memref<*xf32>)
 func @use_of_arg(%arg0: memref<?xf32>) {
-  // CHECK-NEXT: %[[INABIMEMREF:.*]] = refbackrt.to_memref %arg0 : memref<*xf32>
-  // CHECK-NEXT: %[[MEMREF:.*]] = memref_cast %[[INABIMEMREF]] : memref<*xf32> to memref<?xf32>
+  // CHECK-NEXT: %[[MEMREF:.*]] = memref_cast %arg0 : memref<*xf32> to memref<?xf32>
   %c0 = constant 0 : index
   %0 = dim %arg0, %c0 : memref<?xf32>
   // CHECK-NEXT: %[[C0:.*]] = constant 0 : index
   // CHECK-NEXT: dim %[[MEMREF]], %[[C0]]
@@ -57,17 +41,15 @@
 }
 
 // -----
 
-// CHECK-LABEL: func @multiple_blocks(%arg0: !refbackrt.tensor) -> !refbackrt.tensor
+// CHECK-LABEL: func @multiple_blocks(%arg0: memref<*xf32>) -> memref<*xf32>
 func @multiple_blocks(%arg0: memref<?xf32>) -> memref<?xf32> {
-  // CHECK-NEXT: %[[INABIMEMREF:.*]] = refbackrt.to_memref %arg0 : memref<*xf32>
-  // CHECK-NEXT: %[[INMEMREF:.*]] = memref_cast %[[INABIMEMREF]] : memref<*xf32> to memref<?xf32>
+  // CHECK-NEXT: %[[INMEMREF:.*]] = memref_cast %arg0 : memref<*xf32> to memref<?xf32>
  // CHECK-NEXT: br ^bb1(%[[INMEMREF]] : memref<?xf32>)
   br ^bb1(%arg0: memref<?xf32>)
 // CHECK-NEXT: ^bb1(%[[BBARG:.*]]: memref<?xf32>):
 ^bb1(%bbarg: memref<?xf32>):
   // CHECK-NEXT: %[[OUTMEMREF:.*]] = memref_cast %[[BBARG]] : memref<?xf32> to memref<*xf32>
-  // CHECK-NEXT: %[[OUTABIMEMREF:.*]] = refbackrt.from_memref %[[OUTMEMREF]] : memref<*xf32>
-  // CHECK-NEXT: return %[[OUTABIMEMREF]] : !refbackrt.tensor
+  // CHECK-NEXT: return %[[OUTMEMREF]] : memref<*xf32>
   return %bbarg : memref<?xf32>
 }
diff --git a/test/npcomp-run-mlir/constant-add-scalar.mlir b/test/npcomp-run-mlir/constant-add-scalar.mlir
index 2bd9311b3..f658c7f64 100644
--- a/test/npcomp-run-mlir/constant-add-scalar.mlir
+++ b/test/npcomp-run-mlir/constant-add-scalar.mlir
@@ -9,4 +9,4 @@ func @constant_add_scalar(%arg0: tensor<?xf32>) -> tensor<?xf32> {
   %0 = constant dense<1.0> : tensor<f32>
   %1 = tcf.add %arg0, %0 : (tensor<?xf32>, tensor<f32>) -> tensor<?xf32>
   return %1 : tensor<?xf32>
-}
\ No newline at end of file
+}