mirror of https://github.com/llvm/torch-mlir
Consolidate LLVM definitions of runtime data structures.
This required making module descriptors hold a FuncDescriptor* instead of a pointer to array of FuncDescriptors as it previously did, which is innocuous (just requires an llvm.bitcast after the llvm.mlir.addressof).pull/1/head
parent
e228aa4b11
commit
df0d3fcaff
|
@ -93,7 +93,6 @@ def Npcomprt_GetGlobalOp : Npcomprt_Op<"get_global"> {
|
||||||
let results = (outs AnyUnrankedMemRef:$memref);
|
let results = (outs AnyUnrankedMemRef:$memref);
|
||||||
let assemblyFormat = "$global attr-dict `:` type($memref)";
|
let assemblyFormat = "$global attr-dict `:` type($memref)";
|
||||||
let verifier = "return ::verify$cppClass(*this);";
|
let verifier = "return ::verify$cppClass(*this);";
|
||||||
// TODO: verify exists and shape is compatible
|
|
||||||
}
|
}
|
||||||
|
|
||||||
def Npcomprt_ModuleMetadataOp : Npcomprt_Op<"module_metadata", [
|
def Npcomprt_ModuleMetadataOp : Npcomprt_Op<"module_metadata", [
|
||||||
|
|
|
@ -23,10 +23,38 @@ using mlir::LLVM::LLVMFuncOp;
|
||||||
using mlir::LLVM::LLVMType;
|
using mlir::LLVM::LLVMType;
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// Utilities.
|
// Descriptor types shared with the runtime.
|
||||||
|
//
|
||||||
|
// These correspond to the types in CompilerDataStructures.h
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
// TODO: Move other descriptor types to here.
|
// Get the LLVMType for npcomprt::FuncDescriptor.
|
||||||
|
static LLVMType getFuncDescriptorTy(LLVM::LLVMDialect *llvmDialect) {
|
||||||
|
return LLVMType::getStructTy(llvmDialect,
|
||||||
|
{
|
||||||
|
// Name length.
|
||||||
|
LLVMType::getIntNTy(llvmDialect, 32),
|
||||||
|
// Name chars.
|
||||||
|
LLVMType::getInt8PtrTy(llvmDialect),
|
||||||
|
// Type-erased function pointer.
|
||||||
|
LLVMType::getInt8PtrTy(llvmDialect),
|
||||||
|
// Number of inputs.
|
||||||
|
LLVMType::getIntNTy(llvmDialect, 32),
|
||||||
|
// Number of outputs.
|
||||||
|
LLVMType::getIntNTy(llvmDialect, 32),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get the LLVMType for npcomprt::ModuleDescriptor.
|
||||||
|
static LLVMType getModuleDescriptorTy(LLVM::LLVMDialect *llvmDialect) {
|
||||||
|
return LLVMType::getStructTy(
|
||||||
|
llvmDialect, {
|
||||||
|
// std::int32_t numFuncDescriptors;
|
||||||
|
LLVMType::getIntNTy(llvmDialect, 32),
|
||||||
|
// FuncDescriptor *functionDescriptors;
|
||||||
|
getFuncDescriptorTy(llvmDialect).getPointerTo(),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
// Get the LLVMType for npcomprt::GlobalDescriptor.
|
// Get the LLVMType for npcomprt::GlobalDescriptor.
|
||||||
static LLVMType getGlobalDescriptorTy(LLVM::LLVMDialect *llvmDialect) {
|
static LLVMType getGlobalDescriptorTy(LLVM::LLVMDialect *llvmDialect) {
|
||||||
|
@ -374,19 +402,7 @@ createFuncDescriptorArray(ArrayRef<npcomprt::FuncMetadataOp> funcMetadatas,
|
||||||
}
|
}
|
||||||
|
|
||||||
// This must match FuncDescriptor in the runtime.
|
// This must match FuncDescriptor in the runtime.
|
||||||
auto funcDescriptorTy = LLVMType::getStructTy(
|
auto funcDescriptorTy = getFuncDescriptorTy(llvmDialect);
|
||||||
llvmDialect, {
|
|
||||||
// Name length.
|
|
||||||
llvmI32Ty,
|
|
||||||
// Name chars.
|
|
||||||
LLVMType::getInt8PtrTy(llvmDialect),
|
|
||||||
// Type-erased function pointer.
|
|
||||||
LLVMType::getInt8PtrTy(llvmDialect),
|
|
||||||
// Number of inputs.
|
|
||||||
llvmI32Ty,
|
|
||||||
// Number of outputs.
|
|
||||||
llvmI32Ty,
|
|
||||||
});
|
|
||||||
auto funcDescriptorArrayTy =
|
auto funcDescriptorArrayTy =
|
||||||
LLVMType::getArrayTy(funcDescriptorTy, funcMetadatas.size());
|
LLVMType::getArrayTy(funcDescriptorTy, funcMetadatas.size());
|
||||||
auto funcDescriptorArrayGlobal = builder.create<LLVM::GlobalOp>(
|
auto funcDescriptorArrayGlobal = builder.create<LLVM::GlobalOp>(
|
||||||
|
@ -458,11 +474,7 @@ LLVM::GlobalOp createModuleDescriptor(LLVM::GlobalOp funcDescriptorArray,
|
||||||
auto *llvmDialect =
|
auto *llvmDialect =
|
||||||
builder.getContext()->getRegisteredDialect<LLVM::LLVMDialect>();
|
builder.getContext()->getRegisteredDialect<LLVM::LLVMDialect>();
|
||||||
auto llvmI32Ty = LLVMType::getIntNTy(llvmDialect, 32);
|
auto llvmI32Ty = LLVMType::getIntNTy(llvmDialect, 32);
|
||||||
auto moduleDescriptorTy = LLVMType::getStructTy(
|
auto moduleDescriptorTy = getModuleDescriptorTy(llvmDialect);
|
||||||
llvmDialect, {
|
|
||||||
llvmI32Ty,
|
|
||||||
funcDescriptorArray.getType().getPointerTo(),
|
|
||||||
});
|
|
||||||
// TODO: Ideally this symbol name would somehow be related to the module
|
// TODO: Ideally this symbol name would somehow be related to the module
|
||||||
// name, if we could consistently assume we had one.
|
// name, if we could consistently assume we had one.
|
||||||
// TODO: We prepend _mlir so that mlir::ExecutionEngine's lookup logic (which
|
// TODO: We prepend _mlir so that mlir::ExecutionEngine's lookup logic (which
|
||||||
|
@ -490,8 +502,13 @@ LLVM::GlobalOp createModuleDescriptor(LLVM::GlobalOp funcDescriptorArray,
|
||||||
builder.getI32IntegerAttr(
|
builder.getI32IntegerAttr(
|
||||||
funcDescriptorArray.getType().getArrayNumElements())),
|
funcDescriptorArray.getType().getArrayNumElements())),
|
||||||
{0});
|
{0});
|
||||||
updateDescriptor(builder.create<LLVM::AddressOfOp>(loc, funcDescriptorArray),
|
|
||||||
{1});
|
auto funcDecriptorArrayAddress =
|
||||||
|
builder.create<LLVM::AddressOfOp>(loc, funcDescriptorArray);
|
||||||
|
auto rawFuncDescriptorPtr = builder.create<LLVM::BitcastOp>(
|
||||||
|
loc, getFuncDescriptorTy(llvmDialect).getPointerTo(),
|
||||||
|
funcDecriptorArrayAddress);
|
||||||
|
updateDescriptor(rawFuncDescriptorPtr, {1});
|
||||||
builder.create<LLVM::ReturnOp>(loc, moduleDescriptor);
|
builder.create<LLVM::ReturnOp>(loc, moduleDescriptor);
|
||||||
|
|
||||||
return moduleDescriptorGlobal;
|
return moduleDescriptorGlobal;
|
||||||
|
|
|
@ -34,13 +34,14 @@
|
||||||
// CHECK: llvm.return %[[VAL_13]] : !llvm<"[1 x { i32, i8*, i8*, i32, i32 }]">
|
// CHECK: llvm.return %[[VAL_13]] : !llvm<"[1 x { i32, i8*, i8*, i32, i32 }]">
|
||||||
// CHECK: }
|
// CHECK: }
|
||||||
|
|
||||||
// CHECK-LABEL: llvm.mlir.global external constant @_mlir___npcomp_module_descriptor() : !llvm<"{ i32, [1 x { i32, i8*, i8*, i32, i32 }]* }"> {
|
// CHECK-LABEL: llvm.mlir.global external constant @_mlir___npcomp_module_descriptor() : !llvm<"{ i32, { i32, i8*, i8*, i32, i32 }* }"> {
|
||||||
// CHECK: %[[VAL_0:.*]] = llvm.mlir.undef : !llvm<"{ i32, [1 x { i32, i8*, i8*, i32, i32 }]* }">
|
// CHECK: %[[VAL_0:.*]] = llvm.mlir.undef : !llvm<"{ i32, { i32, i8*, i8*, i32, i32 }* }">
|
||||||
// CHECK: %[[VAL_1:.*]] = llvm.mlir.constant(1 : i32) : !llvm.i32
|
// CHECK: %[[VAL_1:.*]] = llvm.mlir.constant(1 : i32) : !llvm.i32
|
||||||
// CHECK: %[[VAL_2:.*]] = llvm.insertvalue %[[VAL_1]], %[[VAL_0]][0 : i32] : !llvm<"{ i32, [1 x { i32, i8*, i8*, i32, i32 }]* }">
|
// CHECK: %[[VAL_2:.*]] = llvm.insertvalue %[[VAL_1]], %[[VAL_0]][0 : i32] : !llvm<"{ i32, { i32, i8*, i8*, i32, i32 }* }">
|
||||||
// CHECK: %[[VAL_3:.*]] = llvm.mlir.addressof @__npcomp_func_descriptors : !llvm<"[1 x { i32, i8*, i8*, i32, i32 }]*">
|
// CHECK: %[[VAL_3:.*]] = llvm.mlir.addressof @__npcomp_func_descriptors : !llvm<"[1 x { i32, i8*, i8*, i32, i32 }]*">
|
||||||
// CHECK: %[[VAL_4:.*]] = llvm.insertvalue %[[VAL_3]], %[[VAL_2]][1 : i32] : !llvm<"{ i32, [1 x { i32, i8*, i8*, i32, i32 }]* }">
|
// CHECK: %[[VAL_4:.*]] = llvm.bitcast %[[VAL_3]] : !llvm<"[1 x { i32, i8*, i8*, i32, i32 }]*"> to !llvm<"{ i32, i8*, i8*, i32, i32 }*">
|
||||||
// CHECK: llvm.return %[[VAL_4]] : !llvm<"{ i32, [1 x { i32, i8*, i8*, i32, i32 }]* }">
|
// CHECK: %[[VAL_5:.*]] = llvm.insertvalue %[[VAL_4]], %[[VAL_2]][1 : i32] : !llvm<"{ i32, { i32, i8*, i8*, i32, i32 }* }">
|
||||||
|
// CHECK: llvm.return %[[VAL_5]] : !llvm<"{ i32, { i32, i8*, i8*, i32, i32 }* }">
|
||||||
// CHECK: }
|
// CHECK: }
|
||||||
|
|
||||||
npcomprt.module_metadata {
|
npcomprt.module_metadata {
|
||||||
|
@ -150,15 +151,6 @@ func @identity(%arg0: !npcomprt.tensor) -> !npcomprt.tensor {
|
||||||
// CHECK: llvm.return %[[VAL_37]] : !llvm<"[3 x { i32, i8*, i8*, i32, i32 }]">
|
// CHECK: llvm.return %[[VAL_37]] : !llvm<"[3 x { i32, i8*, i8*, i32, i32 }]">
|
||||||
// CHECK: }
|
// CHECK: }
|
||||||
|
|
||||||
|
|
||||||
// CHECK-LABEL: llvm.mlir.global external constant @_mlir___npcomp_module_descriptor() : !llvm<"{ i32, [3 x { i32, i8*, i8*, i32, i32 }]* }"> {
|
|
||||||
// CHECK: %[[VAL_0:.*]] = llvm.mlir.undef : !llvm<"{ i32, [3 x { i32, i8*, i8*, i32, i32 }]* }">
|
|
||||||
// CHECK: %[[VAL_1:.*]] = llvm.mlir.constant(3 : i32) : !llvm.i32
|
|
||||||
// CHECK: %[[VAL_2:.*]] = llvm.insertvalue %[[VAL_1]], %[[VAL_0]][0 : i32] : !llvm<"{ i32, [3 x { i32, i8*, i8*, i32, i32 }]* }">
|
|
||||||
// CHECK: %[[VAL_3:.*]] = llvm.mlir.addressof @__npcomp_func_descriptors : !llvm<"[3 x { i32, i8*, i8*, i32, i32 }]*">
|
|
||||||
// CHECK: %[[VAL_4:.*]] = llvm.insertvalue %[[VAL_3]], %[[VAL_2]][1 : i32] : !llvm<"{ i32, [3 x { i32, i8*, i8*, i32, i32 }]* }">
|
|
||||||
// CHECK: llvm.return %[[VAL_4]] : !llvm<"{ i32, [3 x { i32, i8*, i8*, i32, i32 }]* }">
|
|
||||||
// CHECK: }
|
|
||||||
npcomprt.module_metadata {
|
npcomprt.module_metadata {
|
||||||
npcomprt.func_metadata {funcName = @inputs1results0, numInputs = 1 : i32, numOutputs = 0 : i32}
|
npcomprt.func_metadata {funcName = @inputs1results0, numInputs = 1 : i32, numOutputs = 0 : i32}
|
||||||
npcomprt.func_metadata {funcName = @inputs1results1, numInputs = 1 : i32, numOutputs = 1 : i32}
|
npcomprt.func_metadata {funcName = @inputs1results1, numInputs = 1 : i32, numOutputs = 1 : i32}
|
||||||
|
|
Loading…
Reference in New Issue