Consolidate LLVM definitions of runtime data structures.

This required making module descriptors hold a FuncDescriptor* instead
of a pointer to array of FuncDescriptors as it previously did, which is
innocuous (just requires an llvm.bitcast after the llvm.mlir.addressof).
pull/1/head
Sean Silva 2020-07-10 17:50:55 -07:00
parent e228aa4b11
commit df0d3fcaff
3 changed files with 45 additions and 37 deletions

View File

@ -93,7 +93,6 @@ def Npcomprt_GetGlobalOp : Npcomprt_Op<"get_global"> {
let results = (outs AnyUnrankedMemRef:$memref);
let assemblyFormat = "$global attr-dict `:` type($memref)";
let verifier = "return ::verify$cppClass(*this);";
// TODO: verify exists and shape is compatible
}
def Npcomprt_ModuleMetadataOp : Npcomprt_Op<"module_metadata", [

View File

@ -23,10 +23,38 @@ using mlir::LLVM::LLVMFuncOp;
using mlir::LLVM::LLVMType;
//===----------------------------------------------------------------------===//
// Utilities.
// Descriptor types shared with the runtime.
//
// These correspond to the types in CompilerDataStructures.h
//===----------------------------------------------------------------------===//
// TODO: Move other descriptor types to here.
// Get the LLVMType for npcomprt::FuncDescriptor.
static LLVMType getFuncDescriptorTy(LLVM::LLVMDialect *llvmDialect) {
return LLVMType::getStructTy(llvmDialect,
{
// Name length.
LLVMType::getIntNTy(llvmDialect, 32),
// Name chars.
LLVMType::getInt8PtrTy(llvmDialect),
// Type-erased function pointer.
LLVMType::getInt8PtrTy(llvmDialect),
// Number of inputs.
LLVMType::getIntNTy(llvmDialect, 32),
// Number of outputs.
LLVMType::getIntNTy(llvmDialect, 32),
});
}
// Get the LLVMType for npcomprt::ModuleDescriptor.
static LLVMType getModuleDescriptorTy(LLVM::LLVMDialect *llvmDialect) {
return LLVMType::getStructTy(
llvmDialect, {
// std::int32_t numFuncDescriptors;
LLVMType::getIntNTy(llvmDialect, 32),
// FuncDescriptor *functionDescriptors;
getFuncDescriptorTy(llvmDialect).getPointerTo(),
});
}
// Get the LLVMType for npcomprt::GlobalDescriptor.
static LLVMType getGlobalDescriptorTy(LLVM::LLVMDialect *llvmDialect) {
@ -374,19 +402,7 @@ createFuncDescriptorArray(ArrayRef<npcomprt::FuncMetadataOp> funcMetadatas,
}
// This must match FuncDescriptor in the runtime.
auto funcDescriptorTy = LLVMType::getStructTy(
llvmDialect, {
// Name length.
llvmI32Ty,
// Name chars.
LLVMType::getInt8PtrTy(llvmDialect),
// Type-erased function pointer.
LLVMType::getInt8PtrTy(llvmDialect),
// Number of inputs.
llvmI32Ty,
// Number of outputs.
llvmI32Ty,
});
auto funcDescriptorTy = getFuncDescriptorTy(llvmDialect);
auto funcDescriptorArrayTy =
LLVMType::getArrayTy(funcDescriptorTy, funcMetadatas.size());
auto funcDescriptorArrayGlobal = builder.create<LLVM::GlobalOp>(
@ -458,11 +474,7 @@ LLVM::GlobalOp createModuleDescriptor(LLVM::GlobalOp funcDescriptorArray,
auto *llvmDialect =
builder.getContext()->getRegisteredDialect<LLVM::LLVMDialect>();
auto llvmI32Ty = LLVMType::getIntNTy(llvmDialect, 32);
auto moduleDescriptorTy = LLVMType::getStructTy(
llvmDialect, {
llvmI32Ty,
funcDescriptorArray.getType().getPointerTo(),
});
auto moduleDescriptorTy = getModuleDescriptorTy(llvmDialect);
// TODO: Ideally this symbol name would somehow be related to the module
// name, if we could consistently assume we had one.
// TODO: We prepend _mlir so that mlir::ExecutionEngine's lookup logic (which
@ -490,8 +502,13 @@ LLVM::GlobalOp createModuleDescriptor(LLVM::GlobalOp funcDescriptorArray,
builder.getI32IntegerAttr(
funcDescriptorArray.getType().getArrayNumElements())),
{0});
updateDescriptor(builder.create<LLVM::AddressOfOp>(loc, funcDescriptorArray),
{1});
auto funcDecriptorArrayAddress =
builder.create<LLVM::AddressOfOp>(loc, funcDescriptorArray);
auto rawFuncDescriptorPtr = builder.create<LLVM::BitcastOp>(
loc, getFuncDescriptorTy(llvmDialect).getPointerTo(),
funcDecriptorArrayAddress);
updateDescriptor(rawFuncDescriptorPtr, {1});
builder.create<LLVM::ReturnOp>(loc, moduleDescriptor);
return moduleDescriptorGlobal;

View File

@ -34,13 +34,14 @@
// CHECK: llvm.return %[[VAL_13]] : !llvm<"[1 x { i32, i8*, i8*, i32, i32 }]">
// CHECK: }
// CHECK-LABEL: llvm.mlir.global external constant @_mlir___npcomp_module_descriptor() : !llvm<"{ i32, [1 x { i32, i8*, i8*, i32, i32 }]* }"> {
// CHECK: %[[VAL_0:.*]] = llvm.mlir.undef : !llvm<"{ i32, [1 x { i32, i8*, i8*, i32, i32 }]* }">
// CHECK-LABEL: llvm.mlir.global external constant @_mlir___npcomp_module_descriptor() : !llvm<"{ i32, { i32, i8*, i8*, i32, i32 }* }"> {
// CHECK: %[[VAL_0:.*]] = llvm.mlir.undef : !llvm<"{ i32, { i32, i8*, i8*, i32, i32 }* }">
// CHECK: %[[VAL_1:.*]] = llvm.mlir.constant(1 : i32) : !llvm.i32
// CHECK: %[[VAL_2:.*]] = llvm.insertvalue %[[VAL_1]], %[[VAL_0]][0 : i32] : !llvm<"{ i32, [1 x { i32, i8*, i8*, i32, i32 }]* }">
// CHECK: %[[VAL_2:.*]] = llvm.insertvalue %[[VAL_1]], %[[VAL_0]][0 : i32] : !llvm<"{ i32, { i32, i8*, i8*, i32, i32 }* }">
// CHECK: %[[VAL_3:.*]] = llvm.mlir.addressof @__npcomp_func_descriptors : !llvm<"[1 x { i32, i8*, i8*, i32, i32 }]*">
// CHECK: %[[VAL_4:.*]] = llvm.insertvalue %[[VAL_3]], %[[VAL_2]][1 : i32] : !llvm<"{ i32, [1 x { i32, i8*, i8*, i32, i32 }]* }">
// CHECK: llvm.return %[[VAL_4]] : !llvm<"{ i32, [1 x { i32, i8*, i8*, i32, i32 }]* }">
// CHECK: %[[VAL_4:.*]] = llvm.bitcast %[[VAL_3]] : !llvm<"[1 x { i32, i8*, i8*, i32, i32 }]*"> to !llvm<"{ i32, i8*, i8*, i32, i32 }*">
// CHECK: %[[VAL_5:.*]] = llvm.insertvalue %[[VAL_4]], %[[VAL_2]][1 : i32] : !llvm<"{ i32, { i32, i8*, i8*, i32, i32 }* }">
// CHECK: llvm.return %[[VAL_5]] : !llvm<"{ i32, { i32, i8*, i8*, i32, i32 }* }">
// CHECK: }
npcomprt.module_metadata {
@ -150,15 +151,6 @@ func @identity(%arg0: !npcomprt.tensor) -> !npcomprt.tensor {
// CHECK: llvm.return %[[VAL_37]] : !llvm<"[3 x { i32, i8*, i8*, i32, i32 }]">
// CHECK: }
// CHECK-LABEL: llvm.mlir.global external constant @_mlir___npcomp_module_descriptor() : !llvm<"{ i32, [3 x { i32, i8*, i8*, i32, i32 }]* }"> {
// CHECK: %[[VAL_0:.*]] = llvm.mlir.undef : !llvm<"{ i32, [3 x { i32, i8*, i8*, i32, i32 }]* }">
// CHECK: %[[VAL_1:.*]] = llvm.mlir.constant(3 : i32) : !llvm.i32
// CHECK: %[[VAL_2:.*]] = llvm.insertvalue %[[VAL_1]], %[[VAL_0]][0 : i32] : !llvm<"{ i32, [3 x { i32, i8*, i8*, i32, i32 }]* }">
// CHECK: %[[VAL_3:.*]] = llvm.mlir.addressof @__npcomp_func_descriptors : !llvm<"[3 x { i32, i8*, i8*, i32, i32 }]*">
// CHECK: %[[VAL_4:.*]] = llvm.insertvalue %[[VAL_3]], %[[VAL_2]][1 : i32] : !llvm<"{ i32, [3 x { i32, i8*, i8*, i32, i32 }]* }">
// CHECK: llvm.return %[[VAL_4]] : !llvm<"{ i32, [3 x { i32, i8*, i8*, i32, i32 }]* }">
// CHECK: }
npcomprt.module_metadata {
npcomprt.func_metadata {funcName = @inputs1results0, numInputs = 1 : i32, numOutputs = 0 : i32}
npcomprt.func_metadata {funcName = @inputs1results1, numInputs = 1 : i32, numOutputs = 1 : i32}