From df0d3fcaff6d023cf1d83ac39837f97f5b5ee450 Mon Sep 17 00:00:00 2001 From: Sean Silva Date: Fri, 10 Jul 2020 17:50:55 -0700 Subject: [PATCH] Consolidate LLVM definitions of runtime data structures. This required making module descriptors hold a FuncDescriptor* instead of a pointer to array of FuncDescriptors as it previously did, which is innocuous (just requires an llvm.bitcast after the llvm.mlir.addressof). --- .../npcomp/Dialect/Npcomprt/IR/NpcomprtOps.td | 1 - lib/E2E/LowerToLLVM.cpp | 61 ++++++++++++------- test/E2E/lower-to-llvm.mlir | 20 ++---- 3 files changed, 45 insertions(+), 37 deletions(-) diff --git a/include/npcomp/Dialect/Npcomprt/IR/NpcomprtOps.td b/include/npcomp/Dialect/Npcomprt/IR/NpcomprtOps.td index 880fd742f..b1d3fac9c 100644 --- a/include/npcomp/Dialect/Npcomprt/IR/NpcomprtOps.td +++ b/include/npcomp/Dialect/Npcomprt/IR/NpcomprtOps.td @@ -93,7 +93,6 @@ def Npcomprt_GetGlobalOp : Npcomprt_Op<"get_global"> { let results = (outs AnyUnrankedMemRef:$memref); let assemblyFormat = "$global attr-dict `:` type($memref)"; let verifier = "return ::verify$cppClass(*this);"; - // TODO: verify exists and shape is compatible } def Npcomprt_ModuleMetadataOp : Npcomprt_Op<"module_metadata", [ diff --git a/lib/E2E/LowerToLLVM.cpp b/lib/E2E/LowerToLLVM.cpp index 9aa715d8b..1ebbaf154 100644 --- a/lib/E2E/LowerToLLVM.cpp +++ b/lib/E2E/LowerToLLVM.cpp @@ -23,10 +23,38 @@ using mlir::LLVM::LLVMFuncOp; using mlir::LLVM::LLVMType; //===----------------------------------------------------------------------===// -// Utilities. +// Descriptor types shared with the runtime. +// +// These correspond to the types in CompilerDataStructures.h //===----------------------------------------------------------------------===// -// TODO: Move other descriptor types to here. +// Get the LLVMType for npcomprt::FuncDescriptor. +static LLVMType getFuncDescriptorTy(LLVM::LLVMDialect *llvmDialect) { + return LLVMType::getStructTy(llvmDialect, + { + // Name length. + LLVMType::getIntNTy(llvmDialect, 32), + // Name chars. + LLVMType::getInt8PtrTy(llvmDialect), + // Type-erased function pointer. + LLVMType::getInt8PtrTy(llvmDialect), + // Number of inputs. + LLVMType::getIntNTy(llvmDialect, 32), + // Number of outputs. + LLVMType::getIntNTy(llvmDialect, 32), + }); +} + +// Get the LLVMType for npcomprt::ModuleDescriptor. +static LLVMType getModuleDescriptorTy(LLVM::LLVMDialect *llvmDialect) { + return LLVMType::getStructTy( + llvmDialect, { + // std::int32_t numFuncDescriptors; + LLVMType::getIntNTy(llvmDialect, 32), + // FuncDescriptor *functionDescriptors; + getFuncDescriptorTy(llvmDialect).getPointerTo(), + }); +} // Get the LLVMType for npcomprt::GlobalDescriptor. static LLVMType getGlobalDescriptorTy(LLVM::LLVMDialect *llvmDialect) { @@ -374,19 +402,7 @@ createFuncDescriptorArray(ArrayRef funcMetadatas, } // This must match FuncDescriptor in the runtime. - auto funcDescriptorTy = LLVMType::getStructTy( - llvmDialect, { - // Name length. - llvmI32Ty, - // Name chars. - LLVMType::getInt8PtrTy(llvmDialect), - // Type-erased function pointer. - LLVMType::getInt8PtrTy(llvmDialect), - // Number of inputs. - llvmI32Ty, - // Number of outputs. - llvmI32Ty, - }); + auto funcDescriptorTy = getFuncDescriptorTy(llvmDialect); auto funcDescriptorArrayTy = LLVMType::getArrayTy(funcDescriptorTy, funcMetadatas.size()); auto funcDescriptorArrayGlobal = builder.create( @@ -458,11 +474,7 @@ LLVM::GlobalOp createModuleDescriptor(LLVM::GlobalOp funcDescriptorArray, auto *llvmDialect = builder.getContext()->getRegisteredDialect(); auto llvmI32Ty = LLVMType::getIntNTy(llvmDialect, 32); - auto moduleDescriptorTy = LLVMType::getStructTy( - llvmDialect, { - llvmI32Ty, - funcDescriptorArray.getType().getPointerTo(), - }); + auto moduleDescriptorTy = getModuleDescriptorTy(llvmDialect); // TODO: Ideally this symbol name would somehow be related to the module // name, if we could consistently assume we had one. // TODO: We prepend _mlir so that mlir::ExecutionEngine's lookup logic (which @@ -490,8 +502,13 @@ LLVM::GlobalOp createModuleDescriptor(LLVM::GlobalOp funcDescriptorArray, builder.getI32IntegerAttr( funcDescriptorArray.getType().getArrayNumElements())), {0}); - updateDescriptor(builder.create(loc, funcDescriptorArray), - {1}); + + auto funcDecriptorArrayAddress = + builder.create(loc, funcDescriptorArray); + auto rawFuncDescriptorPtr = builder.create( + loc, getFuncDescriptorTy(llvmDialect).getPointerTo(), + funcDecriptorArrayAddress); + updateDescriptor(rawFuncDescriptorPtr, {1}); builder.create(loc, moduleDescriptor); return moduleDescriptorGlobal; diff --git a/test/E2E/lower-to-llvm.mlir b/test/E2E/lower-to-llvm.mlir index 71f1d80a2..7e5c1b36c 100644 --- a/test/E2E/lower-to-llvm.mlir +++ b/test/E2E/lower-to-llvm.mlir @@ -34,13 +34,14 @@ // CHECK: llvm.return %[[VAL_13]] : !llvm<"[1 x { i32, i8*, i8*, i32, i32 }]"> // CHECK: } -// CHECK-LABEL: llvm.mlir.global external constant @_mlir___npcomp_module_descriptor() : !llvm<"{ i32, [1 x { i32, i8*, i8*, i32, i32 }]* }"> { -// CHECK: %[[VAL_0:.*]] = llvm.mlir.undef : !llvm<"{ i32, [1 x { i32, i8*, i8*, i32, i32 }]* }"> +// CHECK-LABEL: llvm.mlir.global external constant @_mlir___npcomp_module_descriptor() : !llvm<"{ i32, { i32, i8*, i8*, i32, i32 }* }"> { +// CHECK: %[[VAL_0:.*]] = llvm.mlir.undef : !llvm<"{ i32, { i32, i8*, i8*, i32, i32 }* }"> // CHECK: %[[VAL_1:.*]] = llvm.mlir.constant(1 : i32) : !llvm.i32 -// CHECK: %[[VAL_2:.*]] = llvm.insertvalue %[[VAL_1]], %[[VAL_0]][0 : i32] : !llvm<"{ i32, [1 x { i32, i8*, i8*, i32, i32 }]* }"> +// CHECK: %[[VAL_2:.*]] = llvm.insertvalue %[[VAL_1]], %[[VAL_0]][0 : i32] : !llvm<"{ i32, { i32, i8*, i8*, i32, i32 }* }"> // CHECK: %[[VAL_3:.*]] = llvm.mlir.addressof @__npcomp_func_descriptors : !llvm<"[1 x { i32, i8*, i8*, i32, i32 }]*"> -// CHECK: %[[VAL_4:.*]] = llvm.insertvalue %[[VAL_3]], %[[VAL_2]][1 : i32] : !llvm<"{ i32, [1 x { i32, i8*, i8*, i32, i32 }]* }"> -// CHECK: llvm.return %[[VAL_4]] : !llvm<"{ i32, [1 x { i32, i8*, i8*, i32, i32 }]* }"> +// CHECK: %[[VAL_4:.*]] = llvm.bitcast %[[VAL_3]] : !llvm<"[1 x { i32, i8*, i8*, i32, i32 }]*"> to !llvm<"{ i32, i8*, i8*, i32, i32 }*"> +// CHECK: %[[VAL_5:.*]] = llvm.insertvalue %[[VAL_4]], %[[VAL_2]][1 : i32] : !llvm<"{ i32, { i32, i8*, i8*, i32, i32 }* }"> +// CHECK: llvm.return %[[VAL_5]] : !llvm<"{ i32, { i32, i8*, i8*, i32, i32 }* }"> // CHECK: } npcomprt.module_metadata { @@ -150,15 +151,6 @@ func @identity(%arg0: !npcomprt.tensor) -> !npcomprt.tensor { // CHECK: llvm.return %[[VAL_37]] : !llvm<"[3 x { i32, i8*, i8*, i32, i32 }]"> // CHECK: } - -// CHECK-LABEL: llvm.mlir.global external constant @_mlir___npcomp_module_descriptor() : !llvm<"{ i32, [3 x { i32, i8*, i8*, i32, i32 }]* }"> { -// CHECK: %[[VAL_0:.*]] = llvm.mlir.undef : !llvm<"{ i32, [3 x { i32, i8*, i8*, i32, i32 }]* }"> -// CHECK: %[[VAL_1:.*]] = llvm.mlir.constant(3 : i32) : !llvm.i32 -// CHECK: %[[VAL_2:.*]] = llvm.insertvalue %[[VAL_1]], %[[VAL_0]][0 : i32] : !llvm<"{ i32, [3 x { i32, i8*, i8*, i32, i32 }]* }"> -// CHECK: %[[VAL_3:.*]] = llvm.mlir.addressof @__npcomp_func_descriptors : !llvm<"[3 x { i32, i8*, i8*, i32, i32 }]*"> -// CHECK: %[[VAL_4:.*]] = llvm.insertvalue %[[VAL_3]], %[[VAL_2]][1 : i32] : !llvm<"{ i32, [3 x { i32, i8*, i8*, i32, i32 }]* }"> -// CHECK: llvm.return %[[VAL_4]] : !llvm<"{ i32, [3 x { i32, i8*, i8*, i32, i32 }]* }"> -// CHECK: } npcomprt.module_metadata { npcomprt.func_metadata {funcName = @inputs1results0, numInputs = 1 : i32, numOutputs = 0 : i32} npcomprt.func_metadata {funcName = @inputs1results1, numInputs = 1 : i32, numOutputs = 1 : i32}