Consolidate LLVM definitions of runtime data structures.

This required making module descriptors hold a FuncDescriptor* instead
of a pointer to array of FuncDescriptors as it previously did, which is
innocuous (just requires an llvm.bitcast after the llvm.mlir.addressof).
pull/1/head
Sean Silva 2020-07-10 17:50:55 -07:00
parent e228aa4b11
commit df0d3fcaff
3 changed files with 45 additions and 37 deletions

View File

@ -93,7 +93,6 @@ def Npcomprt_GetGlobalOp : Npcomprt_Op<"get_global"> {
let results = (outs AnyUnrankedMemRef:$memref); let results = (outs AnyUnrankedMemRef:$memref);
let assemblyFormat = "$global attr-dict `:` type($memref)"; let assemblyFormat = "$global attr-dict `:` type($memref)";
let verifier = "return ::verify$cppClass(*this);"; let verifier = "return ::verify$cppClass(*this);";
// TODO: verify exists and shape is compatible
} }
def Npcomprt_ModuleMetadataOp : Npcomprt_Op<"module_metadata", [ def Npcomprt_ModuleMetadataOp : Npcomprt_Op<"module_metadata", [

View File

@ -23,10 +23,38 @@ using mlir::LLVM::LLVMFuncOp;
using mlir::LLVM::LLVMType; using mlir::LLVM::LLVMType;
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// Utilities. // Descriptor types shared with the runtime.
//
// These correspond to the types in CompilerDataStructures.h
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// TODO: Move other descriptor types to here. // Get the LLVMType for npcomprt::FuncDescriptor.
static LLVMType getFuncDescriptorTy(LLVM::LLVMDialect *llvmDialect) {
return LLVMType::getStructTy(llvmDialect,
{
// Name length.
LLVMType::getIntNTy(llvmDialect, 32),
// Name chars.
LLVMType::getInt8PtrTy(llvmDialect),
// Type-erased function pointer.
LLVMType::getInt8PtrTy(llvmDialect),
// Number of inputs.
LLVMType::getIntNTy(llvmDialect, 32),
// Number of outputs.
LLVMType::getIntNTy(llvmDialect, 32),
});
}
// Get the LLVMType for npcomprt::ModuleDescriptor.
static LLVMType getModuleDescriptorTy(LLVM::LLVMDialect *llvmDialect) {
return LLVMType::getStructTy(
llvmDialect, {
// std::int32_t numFuncDescriptors;
LLVMType::getIntNTy(llvmDialect, 32),
// FuncDescriptor *functionDescriptors;
getFuncDescriptorTy(llvmDialect).getPointerTo(),
});
}
// Get the LLVMType for npcomprt::GlobalDescriptor. // Get the LLVMType for npcomprt::GlobalDescriptor.
static LLVMType getGlobalDescriptorTy(LLVM::LLVMDialect *llvmDialect) { static LLVMType getGlobalDescriptorTy(LLVM::LLVMDialect *llvmDialect) {
@ -374,19 +402,7 @@ createFuncDescriptorArray(ArrayRef<npcomprt::FuncMetadataOp> funcMetadatas,
} }
// This must match FuncDescriptor in the runtime. // This must match FuncDescriptor in the runtime.
auto funcDescriptorTy = LLVMType::getStructTy( auto funcDescriptorTy = getFuncDescriptorTy(llvmDialect);
llvmDialect, {
// Name length.
llvmI32Ty,
// Name chars.
LLVMType::getInt8PtrTy(llvmDialect),
// Type-erased function pointer.
LLVMType::getInt8PtrTy(llvmDialect),
// Number of inputs.
llvmI32Ty,
// Number of outputs.
llvmI32Ty,
});
auto funcDescriptorArrayTy = auto funcDescriptorArrayTy =
LLVMType::getArrayTy(funcDescriptorTy, funcMetadatas.size()); LLVMType::getArrayTy(funcDescriptorTy, funcMetadatas.size());
auto funcDescriptorArrayGlobal = builder.create<LLVM::GlobalOp>( auto funcDescriptorArrayGlobal = builder.create<LLVM::GlobalOp>(
@ -458,11 +474,7 @@ LLVM::GlobalOp createModuleDescriptor(LLVM::GlobalOp funcDescriptorArray,
auto *llvmDialect = auto *llvmDialect =
builder.getContext()->getRegisteredDialect<LLVM::LLVMDialect>(); builder.getContext()->getRegisteredDialect<LLVM::LLVMDialect>();
auto llvmI32Ty = LLVMType::getIntNTy(llvmDialect, 32); auto llvmI32Ty = LLVMType::getIntNTy(llvmDialect, 32);
auto moduleDescriptorTy = LLVMType::getStructTy( auto moduleDescriptorTy = getModuleDescriptorTy(llvmDialect);
llvmDialect, {
llvmI32Ty,
funcDescriptorArray.getType().getPointerTo(),
});
// TODO: Ideally this symbol name would somehow be related to the module // TODO: Ideally this symbol name would somehow be related to the module
// name, if we could consistently assume we had one. // name, if we could consistently assume we had one.
// TODO: We prepend _mlir so that mlir::ExecutionEngine's lookup logic (which // TODO: We prepend _mlir so that mlir::ExecutionEngine's lookup logic (which
@ -490,8 +502,13 @@ LLVM::GlobalOp createModuleDescriptor(LLVM::GlobalOp funcDescriptorArray,
builder.getI32IntegerAttr( builder.getI32IntegerAttr(
funcDescriptorArray.getType().getArrayNumElements())), funcDescriptorArray.getType().getArrayNumElements())),
{0}); {0});
updateDescriptor(builder.create<LLVM::AddressOfOp>(loc, funcDescriptorArray),
{1}); auto funcDecriptorArrayAddress =
builder.create<LLVM::AddressOfOp>(loc, funcDescriptorArray);
auto rawFuncDescriptorPtr = builder.create<LLVM::BitcastOp>(
loc, getFuncDescriptorTy(llvmDialect).getPointerTo(),
funcDecriptorArrayAddress);
updateDescriptor(rawFuncDescriptorPtr, {1});
builder.create<LLVM::ReturnOp>(loc, moduleDescriptor); builder.create<LLVM::ReturnOp>(loc, moduleDescriptor);
return moduleDescriptorGlobal; return moduleDescriptorGlobal;

View File

@ -34,13 +34,14 @@
// CHECK: llvm.return %[[VAL_13]] : !llvm<"[1 x { i32, i8*, i8*, i32, i32 }]"> // CHECK: llvm.return %[[VAL_13]] : !llvm<"[1 x { i32, i8*, i8*, i32, i32 }]">
// CHECK: } // CHECK: }
// CHECK-LABEL: llvm.mlir.global external constant @_mlir___npcomp_module_descriptor() : !llvm<"{ i32, [1 x { i32, i8*, i8*, i32, i32 }]* }"> { // CHECK-LABEL: llvm.mlir.global external constant @_mlir___npcomp_module_descriptor() : !llvm<"{ i32, { i32, i8*, i8*, i32, i32 }* }"> {
// CHECK: %[[VAL_0:.*]] = llvm.mlir.undef : !llvm<"{ i32, [1 x { i32, i8*, i8*, i32, i32 }]* }"> // CHECK: %[[VAL_0:.*]] = llvm.mlir.undef : !llvm<"{ i32, { i32, i8*, i8*, i32, i32 }* }">
// CHECK: %[[VAL_1:.*]] = llvm.mlir.constant(1 : i32) : !llvm.i32 // CHECK: %[[VAL_1:.*]] = llvm.mlir.constant(1 : i32) : !llvm.i32
// CHECK: %[[VAL_2:.*]] = llvm.insertvalue %[[VAL_1]], %[[VAL_0]][0 : i32] : !llvm<"{ i32, [1 x { i32, i8*, i8*, i32, i32 }]* }"> // CHECK: %[[VAL_2:.*]] = llvm.insertvalue %[[VAL_1]], %[[VAL_0]][0 : i32] : !llvm<"{ i32, { i32, i8*, i8*, i32, i32 }* }">
// CHECK: %[[VAL_3:.*]] = llvm.mlir.addressof @__npcomp_func_descriptors : !llvm<"[1 x { i32, i8*, i8*, i32, i32 }]*"> // CHECK: %[[VAL_3:.*]] = llvm.mlir.addressof @__npcomp_func_descriptors : !llvm<"[1 x { i32, i8*, i8*, i32, i32 }]*">
// CHECK: %[[VAL_4:.*]] = llvm.insertvalue %[[VAL_3]], %[[VAL_2]][1 : i32] : !llvm<"{ i32, [1 x { i32, i8*, i8*, i32, i32 }]* }"> // CHECK: %[[VAL_4:.*]] = llvm.bitcast %[[VAL_3]] : !llvm<"[1 x { i32, i8*, i8*, i32, i32 }]*"> to !llvm<"{ i32, i8*, i8*, i32, i32 }*">
// CHECK: llvm.return %[[VAL_4]] : !llvm<"{ i32, [1 x { i32, i8*, i8*, i32, i32 }]* }"> // CHECK: %[[VAL_5:.*]] = llvm.insertvalue %[[VAL_4]], %[[VAL_2]][1 : i32] : !llvm<"{ i32, { i32, i8*, i8*, i32, i32 }* }">
// CHECK: llvm.return %[[VAL_5]] : !llvm<"{ i32, { i32, i8*, i8*, i32, i32 }* }">
// CHECK: } // CHECK: }
npcomprt.module_metadata { npcomprt.module_metadata {
@ -150,15 +151,6 @@ func @identity(%arg0: !npcomprt.tensor) -> !npcomprt.tensor {
// CHECK: llvm.return %[[VAL_37]] : !llvm<"[3 x { i32, i8*, i8*, i32, i32 }]"> // CHECK: llvm.return %[[VAL_37]] : !llvm<"[3 x { i32, i8*, i8*, i32, i32 }]">
// CHECK: } // CHECK: }
// CHECK-LABEL: llvm.mlir.global external constant @_mlir___npcomp_module_descriptor() : !llvm<"{ i32, [3 x { i32, i8*, i8*, i32, i32 }]* }"> {
// CHECK: %[[VAL_0:.*]] = llvm.mlir.undef : !llvm<"{ i32, [3 x { i32, i8*, i8*, i32, i32 }]* }">
// CHECK: %[[VAL_1:.*]] = llvm.mlir.constant(3 : i32) : !llvm.i32
// CHECK: %[[VAL_2:.*]] = llvm.insertvalue %[[VAL_1]], %[[VAL_0]][0 : i32] : !llvm<"{ i32, [3 x { i32, i8*, i8*, i32, i32 }]* }">
// CHECK: %[[VAL_3:.*]] = llvm.mlir.addressof @__npcomp_func_descriptors : !llvm<"[3 x { i32, i8*, i8*, i32, i32 }]*">
// CHECK: %[[VAL_4:.*]] = llvm.insertvalue %[[VAL_3]], %[[VAL_2]][1 : i32] : !llvm<"{ i32, [3 x { i32, i8*, i8*, i32, i32 }]* }">
// CHECK: llvm.return %[[VAL_4]] : !llvm<"{ i32, [3 x { i32, i8*, i8*, i32, i32 }]* }">
// CHECK: }
npcomprt.module_metadata { npcomprt.module_metadata {
npcomprt.func_metadata {funcName = @inputs1results0, numInputs = 1 : i32, numOutputs = 0 : i32} npcomprt.func_metadata {funcName = @inputs1results0, numInputs = 1 : i32, numOutputs = 0 : i32}
npcomprt.func_metadata {funcName = @inputs1results1, numInputs = 1 : i32, numOutputs = 1 : i32} npcomprt.func_metadata {funcName = @inputs1results1, numInputs = 1 : i32, numOutputs = 1 : i32}