mirror of https://github.com/llvm/torch-mlir
Consolidate LLVM definitions of runtime data structures.
This required making module descriptors hold a FuncDescriptor* instead of a pointer to array of FuncDescriptors as it previously did, which is innocuous (just requires an llvm.bitcast after the llvm.mlir.addressof).pull/1/head
parent
e228aa4b11
commit
df0d3fcaff
|
@ -93,7 +93,6 @@ def Npcomprt_GetGlobalOp : Npcomprt_Op<"get_global"> {
|
|||
let results = (outs AnyUnrankedMemRef:$memref);
|
||||
let assemblyFormat = "$global attr-dict `:` type($memref)";
|
||||
let verifier = "return ::verify$cppClass(*this);";
|
||||
// TODO: verify exists and shape is compatible
|
||||
}
|
||||
|
||||
def Npcomprt_ModuleMetadataOp : Npcomprt_Op<"module_metadata", [
|
||||
|
|
|
@ -23,10 +23,38 @@ using mlir::LLVM::LLVMFuncOp;
|
|||
using mlir::LLVM::LLVMType;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Utilities.
|
||||
// Descriptor types shared with the runtime.
|
||||
//
|
||||
// These correspond to the types in CompilerDataStructures.h
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// TODO: Move other descriptor types to here.
|
||||
// Get the LLVMType for npcomprt::FuncDescriptor.
|
||||
static LLVMType getFuncDescriptorTy(LLVM::LLVMDialect *llvmDialect) {
|
||||
return LLVMType::getStructTy(llvmDialect,
|
||||
{
|
||||
// Name length.
|
||||
LLVMType::getIntNTy(llvmDialect, 32),
|
||||
// Name chars.
|
||||
LLVMType::getInt8PtrTy(llvmDialect),
|
||||
// Type-erased function pointer.
|
||||
LLVMType::getInt8PtrTy(llvmDialect),
|
||||
// Number of inputs.
|
||||
LLVMType::getIntNTy(llvmDialect, 32),
|
||||
// Number of outputs.
|
||||
LLVMType::getIntNTy(llvmDialect, 32),
|
||||
});
|
||||
}
|
||||
|
||||
// Get the LLVMType for npcomprt::ModuleDescriptor.
|
||||
static LLVMType getModuleDescriptorTy(LLVM::LLVMDialect *llvmDialect) {
|
||||
return LLVMType::getStructTy(
|
||||
llvmDialect, {
|
||||
// std::int32_t numFuncDescriptors;
|
||||
LLVMType::getIntNTy(llvmDialect, 32),
|
||||
// FuncDescriptor *functionDescriptors;
|
||||
getFuncDescriptorTy(llvmDialect).getPointerTo(),
|
||||
});
|
||||
}
|
||||
|
||||
// Get the LLVMType for npcomprt::GlobalDescriptor.
|
||||
static LLVMType getGlobalDescriptorTy(LLVM::LLVMDialect *llvmDialect) {
|
||||
|
@ -374,19 +402,7 @@ createFuncDescriptorArray(ArrayRef<npcomprt::FuncMetadataOp> funcMetadatas,
|
|||
}
|
||||
|
||||
// This must match FuncDescriptor in the runtime.
|
||||
auto funcDescriptorTy = LLVMType::getStructTy(
|
||||
llvmDialect, {
|
||||
// Name length.
|
||||
llvmI32Ty,
|
||||
// Name chars.
|
||||
LLVMType::getInt8PtrTy(llvmDialect),
|
||||
// Type-erased function pointer.
|
||||
LLVMType::getInt8PtrTy(llvmDialect),
|
||||
// Number of inputs.
|
||||
llvmI32Ty,
|
||||
// Number of outputs.
|
||||
llvmI32Ty,
|
||||
});
|
||||
auto funcDescriptorTy = getFuncDescriptorTy(llvmDialect);
|
||||
auto funcDescriptorArrayTy =
|
||||
LLVMType::getArrayTy(funcDescriptorTy, funcMetadatas.size());
|
||||
auto funcDescriptorArrayGlobal = builder.create<LLVM::GlobalOp>(
|
||||
|
@ -458,11 +474,7 @@ LLVM::GlobalOp createModuleDescriptor(LLVM::GlobalOp funcDescriptorArray,
|
|||
auto *llvmDialect =
|
||||
builder.getContext()->getRegisteredDialect<LLVM::LLVMDialect>();
|
||||
auto llvmI32Ty = LLVMType::getIntNTy(llvmDialect, 32);
|
||||
auto moduleDescriptorTy = LLVMType::getStructTy(
|
||||
llvmDialect, {
|
||||
llvmI32Ty,
|
||||
funcDescriptorArray.getType().getPointerTo(),
|
||||
});
|
||||
auto moduleDescriptorTy = getModuleDescriptorTy(llvmDialect);
|
||||
// TODO: Ideally this symbol name would somehow be related to the module
|
||||
// name, if we could consistently assume we had one.
|
||||
// TODO: We prepend _mlir so that mlir::ExecutionEngine's lookup logic (which
|
||||
|
@ -490,8 +502,13 @@ LLVM::GlobalOp createModuleDescriptor(LLVM::GlobalOp funcDescriptorArray,
|
|||
builder.getI32IntegerAttr(
|
||||
funcDescriptorArray.getType().getArrayNumElements())),
|
||||
{0});
|
||||
updateDescriptor(builder.create<LLVM::AddressOfOp>(loc, funcDescriptorArray),
|
||||
{1});
|
||||
|
||||
auto funcDecriptorArrayAddress =
|
||||
builder.create<LLVM::AddressOfOp>(loc, funcDescriptorArray);
|
||||
auto rawFuncDescriptorPtr = builder.create<LLVM::BitcastOp>(
|
||||
loc, getFuncDescriptorTy(llvmDialect).getPointerTo(),
|
||||
funcDecriptorArrayAddress);
|
||||
updateDescriptor(rawFuncDescriptorPtr, {1});
|
||||
builder.create<LLVM::ReturnOp>(loc, moduleDescriptor);
|
||||
|
||||
return moduleDescriptorGlobal;
|
||||
|
|
|
@ -34,13 +34,14 @@
|
|||
// CHECK: llvm.return %[[VAL_13]] : !llvm<"[1 x { i32, i8*, i8*, i32, i32 }]">
|
||||
// CHECK: }
|
||||
|
||||
// CHECK-LABEL: llvm.mlir.global external constant @_mlir___npcomp_module_descriptor() : !llvm<"{ i32, [1 x { i32, i8*, i8*, i32, i32 }]* }"> {
|
||||
// CHECK: %[[VAL_0:.*]] = llvm.mlir.undef : !llvm<"{ i32, [1 x { i32, i8*, i8*, i32, i32 }]* }">
|
||||
// CHECK-LABEL: llvm.mlir.global external constant @_mlir___npcomp_module_descriptor() : !llvm<"{ i32, { i32, i8*, i8*, i32, i32 }* }"> {
|
||||
// CHECK: %[[VAL_0:.*]] = llvm.mlir.undef : !llvm<"{ i32, { i32, i8*, i8*, i32, i32 }* }">
|
||||
// CHECK: %[[VAL_1:.*]] = llvm.mlir.constant(1 : i32) : !llvm.i32
|
||||
// CHECK: %[[VAL_2:.*]] = llvm.insertvalue %[[VAL_1]], %[[VAL_0]][0 : i32] : !llvm<"{ i32, [1 x { i32, i8*, i8*, i32, i32 }]* }">
|
||||
// CHECK: %[[VAL_2:.*]] = llvm.insertvalue %[[VAL_1]], %[[VAL_0]][0 : i32] : !llvm<"{ i32, { i32, i8*, i8*, i32, i32 }* }">
|
||||
// CHECK: %[[VAL_3:.*]] = llvm.mlir.addressof @__npcomp_func_descriptors : !llvm<"[1 x { i32, i8*, i8*, i32, i32 }]*">
|
||||
// CHECK: %[[VAL_4:.*]] = llvm.insertvalue %[[VAL_3]], %[[VAL_2]][1 : i32] : !llvm<"{ i32, [1 x { i32, i8*, i8*, i32, i32 }]* }">
|
||||
// CHECK: llvm.return %[[VAL_4]] : !llvm<"{ i32, [1 x { i32, i8*, i8*, i32, i32 }]* }">
|
||||
// CHECK: %[[VAL_4:.*]] = llvm.bitcast %[[VAL_3]] : !llvm<"[1 x { i32, i8*, i8*, i32, i32 }]*"> to !llvm<"{ i32, i8*, i8*, i32, i32 }*">
|
||||
// CHECK: %[[VAL_5:.*]] = llvm.insertvalue %[[VAL_4]], %[[VAL_2]][1 : i32] : !llvm<"{ i32, { i32, i8*, i8*, i32, i32 }* }">
|
||||
// CHECK: llvm.return %[[VAL_5]] : !llvm<"{ i32, { i32, i8*, i8*, i32, i32 }* }">
|
||||
// CHECK: }
|
||||
|
||||
npcomprt.module_metadata {
|
||||
|
@ -150,15 +151,6 @@ func @identity(%arg0: !npcomprt.tensor) -> !npcomprt.tensor {
|
|||
// CHECK: llvm.return %[[VAL_37]] : !llvm<"[3 x { i32, i8*, i8*, i32, i32 }]">
|
||||
// CHECK: }
|
||||
|
||||
|
||||
// CHECK-LABEL: llvm.mlir.global external constant @_mlir___npcomp_module_descriptor() : !llvm<"{ i32, [3 x { i32, i8*, i8*, i32, i32 }]* }"> {
|
||||
// CHECK: %[[VAL_0:.*]] = llvm.mlir.undef : !llvm<"{ i32, [3 x { i32, i8*, i8*, i32, i32 }]* }">
|
||||
// CHECK: %[[VAL_1:.*]] = llvm.mlir.constant(3 : i32) : !llvm.i32
|
||||
// CHECK: %[[VAL_2:.*]] = llvm.insertvalue %[[VAL_1]], %[[VAL_0]][0 : i32] : !llvm<"{ i32, [3 x { i32, i8*, i8*, i32, i32 }]* }">
|
||||
// CHECK: %[[VAL_3:.*]] = llvm.mlir.addressof @__npcomp_func_descriptors : !llvm<"[3 x { i32, i8*, i8*, i32, i32 }]*">
|
||||
// CHECK: %[[VAL_4:.*]] = llvm.insertvalue %[[VAL_3]], %[[VAL_2]][1 : i32] : !llvm<"{ i32, [3 x { i32, i8*, i8*, i32, i32 }]* }">
|
||||
// CHECK: llvm.return %[[VAL_4]] : !llvm<"{ i32, [3 x { i32, i8*, i8*, i32, i32 }]* }">
|
||||
// CHECK: }
|
||||
npcomprt.module_metadata {
|
||||
npcomprt.func_metadata {funcName = @inputs1results0, numInputs = 1 : i32, numOutputs = 0 : i32}
|
||||
npcomprt.func_metadata {funcName = @inputs1results1, numInputs = 1 : i32, numOutputs = 1 : i32}
|
||||
|
|
Loading…
Reference in New Issue