[refbackrt] Scalar arg support

* Adds f32 scalar argument support across the ABI boundary.
* Adds support for passing input type / shape information
  across the ABI boundary
* Adds support for parsing / creating input FloatAttr's in
  `npcomp-run-mlir`
pull/197/head
Bryce Arden 2021-03-10 11:53:03 -06:00 committed by Sean Silva
parent 703428eff4
commit 4591884d06
17 changed files with 1085 additions and 269 deletions

View File

@ -68,12 +68,47 @@ def Refbackrt_FuncMetadataOp
let description = [{ let description = [{
Runtime metadata for a single func. Runtime metadata for a single func.
TODO: Augment this with information for type/shape checking of arguments. Contains type / shape information for arguments as described below:
* ArgType(s):
Integer value from `CompilerDataStructures.h` for each argument
indicating what type it is (e.g. Float, Int, Tensor, Dict, etc.)
* ElementType(s):
Certain input ArgType's also have an element type (e.g. Tensor<float>,
List<int>, etc.)
TODO(brycearden): Support nested types (e.g. List<Tensor<float>>)
* Rank(s):
Integer value indicating the rank for each argument.
* Shape(s):
Flattened hyper-rectangular representation of the shapes for each argument.
Since each shape's size varies based on the Rank, we pad out the shapes
to size kMaxRank to make ABI lowering easier. See `LowerToRefbackrtABI.cpp`
for details.
Shapes Example:
constexpr int kMaxRank = 6;
// func @f(%arg0: f32, %arg1: tensor<5xf32>) would result in...
inputShapes = dense<...> : tensor<12xi32>
// 2 shapes with 6 elements each, where only the first `rank`
// values in each shape are valid, so that the LowerToLLVM pass
// can update the struct(s) by just grabbing a pointer at
// %shape_ptr = %base + (kMaxRank * argIndex)
}]; }];
let arguments = (ins let arguments = (ins
FlatSymbolRefAttr:$funcName, FlatSymbolRefAttr:$funcName,
I32Attr:$numInputs, I32Attr:$numInputs,
I32Attr:$numOutputs I32Attr:$numOutputs,
OptionalAttr<I32ElementsAttr>:$inputArgTypes,
OptionalAttr<I32ElementsAttr>:$inputElementTypes,
OptionalAttr<I32ElementsAttr>:$inputRanks,
OptionalAttr<I32ElementsAttr>:$inputShapes,
// I32ElementsAttr:$inputIsStatic,
OptionalAttr<I32ElementsAttr>:$outputArgTypes,
OptionalAttr<I32ElementsAttr>:$outputElementTypes,
OptionalAttr<I32ElementsAttr>:$outputRanks,
OptionalAttr<I32ElementsAttr>:$outputShapes
//I32ElementsAttr:$outputIsStatic
); );
let results = (outs); let results = (outs);
let assemblyFormat = "attr-dict"; let assemblyFormat = "attr-dict";

View File

@ -31,6 +31,8 @@ public:
return std::memcmp(ptr, other.ptr, length) == 0; return std::memcmp(ptr, other.ptr, length) == 0;
} }
const char* str() { return ptr; }
private: private:
const char *ptr; const char *ptr;
std::size_t length; std::size_t length;

View File

@ -22,6 +22,7 @@
#define NPCOMP_RUNTIME_USERAPI_H #define NPCOMP_RUNTIME_USERAPI_H
#include "npcomp/RefBackend/Runtime/Support.h" #include "npcomp/RefBackend/Runtime/Support.h"
#include <array>
#include <atomic> #include <atomic>
#include <cstdlib> #include <cstdlib>
@ -105,9 +106,11 @@ private:
// The available data types. // The available data types.
enum class ElementType : std::int32_t { enum class ElementType : std::int32_t {
NONE,
F32, F32,
}; };
std::int32_t getElementTypeByteSize(ElementType type); std::int32_t getElementTypeByteSize(ElementType type);
StringRef getElementTypeAsStringRef(ElementType type);
// Representation of a tensor. // Representation of a tensor.
class Tensor : public RefTarget { class Tensor : public RefTarget {
@ -124,6 +127,12 @@ public:
static Tensor *createRaw(ArrayRef<std::int32_t> extents, static Tensor *createRaw(ArrayRef<std::int32_t> extents,
ElementType elementType, void *data); ElementType elementType, void *data);
static Ref<Tensor> create(ArrayRef<std::int64_t> extents,
ElementType elementType, void *data);
// Same as `create`, but returns a raw pointer.
static Tensor *createRaw(ArrayRef<std::int64_t> extents,
ElementType elementType, void *data);
ElementType getElementType() const { return elementType; } ElementType getElementType() const { return elementType; }
std::int32_t getRank() const { return rank; } std::int32_t getRank() const { return rank; }
void *getData() const { return data; } void *getData() const { return data; }
@ -169,6 +178,7 @@ private:
_(None) \ _(None) \
_(Bool) \ _(Bool) \
_(Int) \ _(Int) \
_(Float) \
_(Double) _(Double)
#define NPCOMP_FORALL_REF_TAGS(_) _(Tensor) #define NPCOMP_FORALL_REF_TAGS(_) _(Tensor)
@ -193,15 +203,23 @@ struct RtValue final {
RtValue(std::int64_t i) : tag(Tag::Int) { payload.asInt = i; } RtValue(std::int64_t i) : tag(Tag::Int) { payload.asInt = i; }
RtValue(std::int32_t i) : RtValue(static_cast<int64_t>(i)) {} RtValue(std::int32_t i) : RtValue(static_cast<int64_t>(i)) {}
bool isInt() const { return Tag::Int == tag; } bool isInt() const { return Tag::Int == tag; }
bool toInt() const { int64_t toInt() const {
assert(isInt()); assert(isInt());
return payload.asInt; return payload.asInt;
} }
// Float
RtValue(float f) : tag(Tag::Float) { payload.asFloat = f; }
bool isFloat() const { return Tag::Float == tag; }
float toFloat() const {
assert(isFloat());
return payload.asFloat;
}
// Double // Double
RtValue(double d) : tag(Tag::Double) { payload.asDouble = d; } RtValue(double d) : tag(Tag::Double) { payload.asDouble = d; }
bool isDouble() const { return Tag::Double == tag; } bool isDouble() const { return Tag::Double == tag; }
bool toDouble() const { double toDouble() const {
assert(isDouble()); assert(isDouble());
return payload.asDouble; return payload.asDouble;
} }
@ -227,6 +245,11 @@ struct RtValue final {
return false; return false;
} }
// Scalar
bool isScalar() const {
return isBool() || isInt() || isFloat() || isDouble();
}
// RtValue (downcast) // RtValue (downcast)
const RtValue &toRtValue() const { return *this; } const RtValue &toRtValue() const { return *this; }
RtValue &toRtValue() { return *this; } RtValue &toRtValue() { return *this; }
@ -298,6 +321,7 @@ private:
union Payload { union Payload {
bool asBool; bool asBool;
int64_t asInt; int64_t asInt;
float asFloat;
double asDouble; double asDouble;
void *asVoidPtr; void *asVoidPtr;
}; };
@ -313,19 +337,72 @@ private:
// This is the main entry point that users interact with. // This is the main entry point that users interact with.
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// Tag identifying what kind of value is passed across the ABI boundary.
// NOTE: the integer encoding must stay in sync with the values emitted by
// the compiler (see getIntReprForABIType in LowerToRefbackrtABI.cpp:
// Tensor = 1, F32 = 2, F64 = 3).
enum class ArgType : std::uint32_t {
kNone = 0,
kTensor,
kF32,
kF64,
};
StringRef getArgTypeAsStringRef(ArgType type);
// Maximum rank supported across the ABI boundary
constexpr static int kMaxRank = 6;
// Type/shape metadata for a single input argument, as described by the
// compiler-emitted function descriptors.
struct InputArgInfo {
// What type of argument this is
ArgType argType;
// Certain arg types also have an element type
ElementType elementType;
// Number of valid entries in `extents` (presumably 0 for scalar args --
// TODO confirm against the compiler side).
std::int32_t rank;
// Shape padded out to kMaxRank; only the first `rank` entries are valid.
std::array<std::int32_t, kMaxRank> extents;
};
// Type/shape metadata for a single output argument, as described by the
// compiler-emitted function descriptors.
struct OutputArgInfo {
// What type of argument this is
ArgType argType;
// Certain arg types also have an element type
ElementType elementType;
// Number of valid entries in `extents`.
std::int32_t rank;
// Shape padded out to kMaxRank; only the first `rank` entries are valid.
std::array<std::int32_t, kMaxRank> extents;
// TODO(brycearden): Add checks for whether output buffers alias to input
// buffers and populate field(s) here indicating that case
};
// Maximum input or output arity.
constexpr static int kMaxArity = 20;
// Metadata for a particular function. // Metadata for a particular function.
// TODO: Add arg types.
struct FunctionMetadata { struct FunctionMetadata {
std::int32_t numInputs; std::int32_t numInputs;
std::int32_t numOutputs; std::int32_t numOutputs;
std::array<InputArgInfo, kMaxArity> inputArgInfos;
std::array<OutputArgInfo, kMaxArity> outputArgInfos;
}; };
// Opaque forward declaration of module descriptor type. This is the type // Opaque forward declaration of module descriptor type. This is the type
// created by the compiler in the module binary. // created by the compiler in the module binary.
struct ModuleDescriptor; struct ModuleDescriptor;
// Maximum input or output arity. // Verifies that the input RtValue arg types match what the user provides
constexpr static int kMaxArity = 20; // matches the types we expect from the descriptors emitted by the
// compiler.
//
// Returns failure if the input type(s) are not valid
LogicalResult checkRtValueArgTypes(const RtValue &value,
const InputArgInfo &info);
// Verifies that the shapes of the input RtValue(s) provided by the user
// match the shapes we expect from the descriptors emitted by the
// compiler.
//
// Returns failure if the input type(s) are not valid
LogicalResult checkRtValueShapes(const RtValue &value,
const InputArgInfo &info);
// Creates an RtValue of the right type from the output metadata
// provided by the compiled module
RtValue createRtValueFromOutputArgInfo(const OutputArgInfo &info);
// Low-level invocation API. The number of inputs and outputs should be correct // Low-level invocation API. The number of inputs and outputs should be correct
// and match the results of getMetadata. // and match the results of getMetadata.

View File

@ -54,6 +54,16 @@ static LogicalResult verify(FuncMetadataOp op) {
return op.emitError() << "must agree on number of inputs"; return op.emitError() << "must agree on number of inputs";
if (op.numOutputs() != func.getNumResults()) if (op.numOutputs() != func.getNumResults())
return op.emitError() << "must agree on number of outputs"; return op.emitError() << "must agree on number of outputs";
if (op.numInputs() > 0) {
if (op.numInputs() != op.inputArgTypes()->size()) {
return op.emitError() << "number of inputTypes must match number of inputs";
}
}
if (op.numOutputs() > 0) {
if (op.numOutputs() != op.outputArgTypes()->size())
return op.emitError() << "number of outputTypes must match number of outputs";
}
return success(); return success();
} }

View File

@ -12,6 +12,8 @@
#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h" #include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h"
#include "npcomp/RefBackend/RefBackend.h" #include "npcomp/RefBackend/RefBackend.h"
#include <sstream>
using namespace refback; using namespace refback;
using namespace mlir; using namespace mlir;
using llvm::Error; using llvm::Error;
@ -73,6 +75,22 @@ static refbackrt::MutableArrayRef<T> toRefbackrt(llvm::MutableArrayRef<T> a) {
return refbackrt::MutableArrayRef<T>(a.data(), a.size()); return refbackrt::MutableArrayRef<T>(a.data(), a.size());
} }
// Renders a shape as a human-readable string, e.g. "(2x?x4)".
// Dynamic dimensions (negative extents) are printed as "?".
static std::string stringifyShape(refbackrt::ArrayRef<std::int32_t> extents) {
  // A string literal has type `const char[N]`; binding it to a non-const
  // `char *` is ill-formed (removed in C++11), so the pointee must be const.
  static constexpr const char *kDynamicDimAsString = "?";
  std::stringstream ss;
  ss << "(";
  // Use std::size_t to avoid a signed/unsigned comparison against size().
  for (std::size_t i = 0; i < extents.size(); i++) {
    if (extents[i] < 0)
      ss << kDynamicDimAsString;
    else
      ss << extents[i];
    // Separate dimensions with 'x', but do not trail one after the last.
    if (i != extents.size() - 1)
      ss << "x";
  }
  ss << ")";
  return ss.str();
}
llvm::Expected<llvm::SmallVector<refbackrt::RtValue, 6>> llvm::Expected<llvm::SmallVector<refbackrt::RtValue, 6>>
JITModule::invoke(llvm::StringRef functionName, JITModule::invoke(llvm::StringRef functionName,
llvm::ArrayRef<refbackrt::RtValue> inputs) { llvm::ArrayRef<refbackrt::RtValue> inputs) {
@ -80,12 +98,47 @@ JITModule::invoke(llvm::StringRef functionName,
if (refbackrt::failed(refbackrt::getMetadata( if (refbackrt::failed(refbackrt::getMetadata(
descriptor, toRefbackrt(functionName), metadata))) descriptor, toRefbackrt(functionName), metadata)))
return make_string_error("unknown function: " + Twine(functionName)); return make_string_error("unknown function: " + Twine(functionName));
SmallVector<refbackrt::RtValue, 6> outputs( SmallVector<refbackrt::RtValue, 6> outputs(metadata.numOutputs);
metadata.numOutputs);
if (metadata.numInputs != static_cast<std::int32_t>(inputs.size())) if (metadata.numInputs != static_cast<std::int32_t>(inputs.size()))
return make_string_error("invoking '" + Twine(functionName) + return make_string_error("invoking '" + Twine(functionName) +
"': expected " + Twine(metadata.numInputs) + "': expected " + Twine(metadata.numInputs) +
" inputs"); " inputs");
// Verify user input types and shapes match what the compiler expects
for (int i = 0; i < metadata.numInputs; i++) {
auto &input = inputs[i];
auto &inputArgInfo = metadata.inputArgInfos[i];
if (refbackrt::failed(checkRtValueArgTypes(input, inputArgInfo)))
return make_string_error(
"invoking '" + Twine(functionName) +
"': input argument type mismatch. actual (provided by user): " +
Twine(inputs[i].tagKind().str()) + ", expected (from compiler): " +
Twine(getArgTypeAsStringRef(inputArgInfo.argType).str()));
if (refbackrt::failed(checkRtValueShapes(input, inputArgInfo)))
return make_string_error(
"invoking '" + Twine(functionName) + "': input shape mismatch (%arg" +
Twine(i) + "). " + "actual (provided by user): " +
stringifyShape(input.toTensor()->getExtents()) +
", expected (from compiler): " +
stringifyShape(refbackrt::ArrayRef<int32_t>(
inputArgInfo.extents.data(), inputArgInfo.rank)));
}
// Create the correct output RtValue based on FuncMetadata,
// which contains the arg types (scalar, Tensor, etc.), element types (only
// applicable if not scalar) and shapes (also only applicable if not scalar)
//
// Currently we have to give each RtValue an output type so that we know
// how to pack / unpack the outputs properly across the ABI boundary in
// refbackrt::invoke. As a result, we can't just rely on the default
// construction of each output argument type (otherwise RtValue will have
// Tag::kNone) currently without passing the ArgInfo structs down to the
// Runtime level, so we deal with the output type creation here.
for (int i = 0; i < metadata.numOutputs; i++) {
outputs[i] = std::move(
refbackrt::createRtValueFromOutputArgInfo(metadata.outputArgInfos[i]));
}
refbackrt::invoke( refbackrt::invoke(
descriptor, toRefbackrt(functionName), toRefbackrt(inputs), descriptor, toRefbackrt(functionName), toRefbackrt(inputs),
toRefbackrt(llvm::makeMutableArrayRef(outputs.data(), outputs.size()))); toRefbackrt(llvm::makeMutableArrayRef(outputs.data(), outputs.size())));

View File

@ -34,25 +34,70 @@ using mlir::LLVM::LLVMVoidType;
// These correspond to the types in CompilerDataStructures.h // These correspond to the types in CompilerDataStructures.h
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// MaxRank that the refbackrt ABI lowering is capable of handling
// NOTE: This parameter must stay consistent with
// `lib/RefBackend/LowerToRefbackrtABI.cpp`
static constexpr int kMaxRank = 6;
static LLVMPointerType getInt8PointerType(MLIRContext *context) { static LLVMPointerType getInt8PointerType(MLIRContext *context) {
return LLVMPointerType::get(IntegerType::get(context, 8)); return LLVMPointerType::get(IntegerType::get(context, 8));
} }
// Helper for building the LLVM-dialect `i32*` type.
static LLVMPointerType getInt32PointerType(MLIRContext *context) {
  auto i32Ty = IntegerType::get(context, 32);
  return LLVMPointerType::get(i32Ty);
}
// LLVM struct type mirroring the runtime's per-input argument descriptor:
// { ArgType : i32, ElementType : i32, Rank : i32, Extents : i32* }.
// (An `IsStatic : i32` field may be added here in the future.)
static LLVMStructType getInputDescriptorTy(MLIRContext *context) {
  auto i32Ty = IntegerType::get(context, 32);
  auto extentsPtrTy = LLVMPointerType::get(i32Ty);
  return LLVMStructType::getLiteral(context,
                                    {i32Ty, i32Ty, i32Ty, extentsPtrTy});
}
// LLVM struct type mirroring the runtime's per-output argument descriptor:
// { ArgType : i32, ElementType : i32, Rank : i32, Extents : i32* }.
// (An `IsStatic : i32` field may be added here in the future.)
static LLVMStructType getOutputDescriptorTy(MLIRContext *context) {
  auto i32Ty = IntegerType::get(context, 32);
  auto extentsPtrTy = LLVMPointerType::get(i32Ty);
  return LLVMStructType::getLiteral(context,
                                    {i32Ty, i32Ty, i32Ty, extentsPtrTy});
}
// Get the LLVM type for refbackrt::FuncDescriptor. // Get the LLVM type for refbackrt::FuncDescriptor.
static LLVMStructType getFuncDescriptorTy(MLIRContext *context) { static LLVMStructType getFuncDescriptorTy(MLIRContext *context) {
return LLVMStructType::getLiteral(context, return LLVMStructType::getLiteral(
{ context, {
// Name length. // Name length.
IntegerType::get(context, 32), IntegerType::get(context, 32),
// Name chars. // Name chars.
getInt8PointerType(context), getInt8PointerType(context),
// Type-erased function pointer. // Type-erased function pointer.
getInt8PointerType(context), getInt8PointerType(context),
// Number of inputs. // Number of inputs.
IntegerType::get(context, 32), IntegerType::get(context, 32),
// Number of outputs. // Number of outputs.
IntegerType::get(context, 32), IntegerType::get(context, 32),
}); // Argument descriptors
LLVMPointerType::get(getInputDescriptorTy(context)),
// Result Descriptors
LLVMPointerType::get(getOutputDescriptorTy(context)),
});
} }
// Get the LLVM type for refbackrt::ModuleDescriptor. // Get the LLVM type for refbackrt::ModuleDescriptor.
@ -92,8 +137,8 @@ static LLVM::GlobalOp createGlobalString(ModuleOp module, StringAttr msg,
// TODO: Deduplicate strings. // TODO: Deduplicate strings.
std::string msgNulTerminated = msg.getValue().str(); std::string msgNulTerminated = msg.getValue().str();
msgNulTerminated.push_back('\0'); msgNulTerminated.push_back('\0');
auto arrayTy = LLVMArrayType::get( auto arrayTy = LLVMArrayType::get(IntegerType::get(module.getContext(), 8),
IntegerType::get(module.getContext(), 8), msgNulTerminated.size()); msgNulTerminated.size());
OpBuilder::InsertionGuard guard(builder); OpBuilder::InsertionGuard guard(builder);
builder.setInsertionPointToStart(module.getBody()); builder.setInsertionPointToStart(module.getBody());
@ -129,9 +174,9 @@ public:
auto globalOp = createGlobalString(op->getParentOfType<ModuleOp>(), auto globalOp = createGlobalString(op->getParentOfType<ModuleOp>(),
op.msgAttr(), rewriter, op.getLoc()); op.msgAttr(), rewriter, op.getLoc());
auto msgArray = rewriter.create<LLVM::AddressOfOp>(op.getLoc(), globalOp); auto msgArray = rewriter.create<LLVM::AddressOfOp>(op.getLoc(), globalOp);
auto c0 = rewriter.create<LLVM::ConstantOp>( auto c0 = rewriter.create<LLVM::ConstantOp>(op.getLoc(),
op.getLoc(), IntegerType::get(context, 32), IntegerType::get(context, 32),
rewriter.getI32IntegerAttr(0)); rewriter.getI32IntegerAttr(0));
auto msg = auto msg =
rewriter.create<LLVM::GEPOp>(op.getLoc(), getInt8PointerType(context), rewriter.create<LLVM::GEPOp>(op.getLoc(), getInt8PointerType(context),
msgArray, ValueRange({c0, c0})); msgArray, ValueRange({c0, c0}));
@ -181,16 +226,181 @@ createFuncDescriptorArray(ArrayRef<refbackrt::FuncMetadataOp> funcMetadatas,
auto llvmI32Ty = IntegerType::get(builder.getContext(), 32); auto llvmI32Ty = IntegerType::get(builder.getContext(), 32);
DenseMap<StringRef, LLVM::GlobalOp> globalsByName; DenseMap<StringRef, LLVM::GlobalOp> globalsByName;
DenseMap<StringRef, LLVM::GlobalOp> inputDescriptorsByName;
DenseMap<StringRef, LLVM::GlobalOp> outputDescriptorsByName;
DenseMap<StringRef, LLVM::GlobalOp> inputShapesByName;
DenseMap<StringRef, LLVM::GlobalOp> outputShapesByName;
for (auto funcMetadata : funcMetadatas) { for (auto funcMetadata : funcMetadatas) {
auto arrayTy = auto arrayTy = LLVMArrayType::get(IntegerType::get(builder.getContext(), 8),
LLVMArrayType::get(IntegerType::get(builder.getContext(), 8), funcMetadata.funcName().size());
funcMetadata.funcName().size());
std::string llvmSymbolName = std::string llvmSymbolName =
(Twine("__npcomp_internal_constant_") + funcMetadata.funcName()).str(); (Twine("__npcomp_internal_constant_") + funcMetadata.funcName()).str();
auto global = builder.create<LLVM::GlobalOp>( auto global = builder.create<LLVM::GlobalOp>(
loc, arrayTy, /*isConstant=*/true, LLVM::Linkage::Internal, loc, arrayTy, /*isConstant=*/true, LLVM::Linkage::Internal,
llvmSymbolName, builder.getStringAttr(funcMetadata.funcName())); llvmSymbolName, builder.getStringAttr(funcMetadata.funcName()));
globalsByName[funcMetadata.funcName()] = global; globalsByName[funcMetadata.funcName()] = global;
// Create constants for the input / output shapes
if (funcMetadata.inputShapes().hasValue()) {
auto i32ArrayInputSymbolName =
(Twine("__npcomp_internal_constant_input_shapes_") +
funcMetadata.funcName())
.str();
auto inputNumElements = funcMetadata.inputShapes()->getNumElements();
auto inputI32ArrayTy =
LLVMArrayType::get(builder.getIntegerType(32), inputNumElements);
auto inputShapesGlobal = builder.create<LLVM::GlobalOp>(
loc, inputI32ArrayTy, /*isConstant=*/true, LLVM::Linkage::Internal,
i32ArrayInputSymbolName,
/*value=*/funcMetadata.inputShapes().getValue());
inputShapesByName[funcMetadata.funcName()] = inputShapesGlobal;
}
if (funcMetadata.outputShapes().hasValue()) {
auto i32ArrayOutputSymbolName =
(Twine("__npcomp_internal_constant_output_shapes_") +
funcMetadata.funcName())
.str();
auto outputNumElements = funcMetadata.outputShapes()->getNumElements();
auto outputI32ArrayTy =
LLVMArrayType::get(builder.getIntegerType(32), outputNumElements);
auto outputShapesGlobal = builder.create<LLVM::GlobalOp>(
loc, outputI32ArrayTy, /*isConstant=*/true, LLVM::Linkage::Internal,
i32ArrayOutputSymbolName,
/*value=*/funcMetadata.outputShapes().getValue());
outputShapesByName[funcMetadata.funcName()] = outputShapesGlobal;
}
}
auto updateDescriptor = [&](Value &descriptor, Value value,
std::initializer_list<int32_t> position) {
descriptor = builder.create<LLVM::InsertValueOp>(
loc, descriptor, value,
/*position=*/builder.getI32ArrayAttr(position));
};
auto updateDescriptorWithI32Attr =
[&](Value &descriptor, Attribute attr,
std::initializer_list<int32_t> position) {
auto constant = builder.create<LLVM::ConstantOp>(loc, llvmI32Ty, attr);
updateDescriptor(descriptor, constant, position);
};
// Create global input descriptors
for (auto funcMetadata : funcMetadatas) {
std::string llvmInputSymbolName =
(Twine("__npcomp_input_descriptors_") + funcMetadata.funcName()).str();
auto inputDescriptorTy = getInputDescriptorTy(builder.getContext());
auto inputDescriptorArrayTy =
LLVMArrayType::get(inputDescriptorTy, funcMetadata.numInputs());
auto inputDescriptorArrayGlobal = builder.create<LLVM::GlobalOp>(
loc, inputDescriptorArrayTy, /*isConstant=*/true,
LLVM::Linkage::Internal, llvmInputSymbolName, /*value=*/Attribute());
OpBuilder::InsertionGuard guard(builder);
builder.createBlock(&inputDescriptorArrayGlobal.initializer());
auto c0 = builder.create<LLVM::ConstantOp>(loc, llvmI32Ty,
builder.getI32IntegerAttr(0));
Value inputDescriptorArray =
builder.create<LLVM::UndefOp>(loc, inputDescriptorArrayTy);
for (int i = 0; i < funcMetadata.numInputs(); i++) {
// Arg Type
if (!funcMetadata.inputArgTypes().hasValue())
funcMetadata.emitError()
<< "numInputs > 0 but there are no inputArgTypes?";
updateDescriptorWithI32Attr(inputDescriptorArray,
funcMetadata.inputArgTypes()->getValue(i),
{i, 0});
// Element Type
updateDescriptorWithI32Attr(inputDescriptorArray,
funcMetadata.inputElementTypes()->getValue(i),
{i, 1});
// Rank
// auto inputShapesType =
// funcMetadata.inputShapes()->getType().dyn_cast<ShapedType>();
auto rank = funcMetadata.inputRanks()->getValue(i);
updateDescriptorWithI32Attr(inputDescriptorArray, rank, {i, 2});
// Shape
// Each shape array is derived by offsetting by kMaxRank * arg index
auto extentsArray = builder.create<LLVM::AddressOfOp>(
loc, inputShapesByName[funcMetadata.funcName()]);
auto cShapeOffset = builder.create<LLVM::ConstantOp>(
loc, IntegerType::get(builder.getContext(), 32),
builder.getI32IntegerAttr(i * kMaxRank));
auto extentsArrayPtr = builder.create<LLVM::GEPOp>(
loc, getInt32PointerType(builder.getContext()), extentsArray,
ValueRange({c0, cShapeOffset}));
updateDescriptor(inputDescriptorArray, extentsArrayPtr, {i, 3});
}
builder.create<LLVM::ReturnOp>(loc, inputDescriptorArray);
inputDescriptorsByName[funcMetadata.funcName()] =
std::move(inputDescriptorArrayGlobal);
}
// Create global output descriptors
for (auto funcMetadata : funcMetadatas) {
std::string llvmOutputSymbolName =
(Twine("__npcomp_output_descriptors_") + funcMetadata.funcName()).str();
auto outputDescriptorTy = getOutputDescriptorTy(builder.getContext());
auto outputDescriptorArrayTy =
LLVMArrayType::get(outputDescriptorTy, funcMetadata.numOutputs());
auto outputDescriptorArrayGlobal = builder.create<LLVM::GlobalOp>(
loc, outputDescriptorArrayTy, /*isConstant=*/true,
LLVM::Linkage::Internal, llvmOutputSymbolName, /*value=*/Attribute());
OpBuilder::InsertionGuard guard(builder);
builder.createBlock(&outputDescriptorArrayGlobal.initializer());
auto c0 = builder.create<LLVM::ConstantOp>(loc, llvmI32Ty,
builder.getI32IntegerAttr(0));
Value outputDescriptorArray =
builder.create<LLVM::UndefOp>(loc, outputDescriptorArrayTy);
for (int i = 0; i < funcMetadata.numOutputs(); i++) {
if (!funcMetadata.outputArgTypes().hasValue())
funcMetadata.emitError()
<< "numOutputs > 0 but there are no outputArgTypes?";
// Arg Type
updateDescriptorWithI32Attr(outputDescriptorArray,
funcMetadata.outputArgTypes()->getValue(i),
{i, 0});
// Element Type
updateDescriptorWithI32Attr(
outputDescriptorArray, funcMetadata.outputElementTypes()->getValue(i),
{i, 1});
// Rank
// auto outputShapesType =
// funcMetadata.outputShapes()->getType().dyn_cast<ShapedType>();
auto rank = funcMetadata.outputRanks()->getValue(i);
updateDescriptorWithI32Attr(outputDescriptorArray, rank, {i, 2});
// Shapes
// Offset by kMaxRank * arg index
auto extentsArray = builder.create<LLVM::AddressOfOp>(
loc, outputShapesByName[funcMetadata.funcName()]);
auto cShapeOffset = builder.create<LLVM::ConstantOp>(
loc, IntegerType::get(builder.getContext(), 32),
builder.getI32IntegerAttr(i * kMaxRank));
auto extentsArrayPtr = builder.create<LLVM::GEPOp>(
loc, getInt32PointerType(builder.getContext()), extentsArray,
ValueRange({c0, cShapeOffset}));
updateDescriptor(outputDescriptorArray, extentsArrayPtr, {i, 3});
}
builder.create<LLVM::ReturnOp>(loc, outputDescriptorArray);
outputDescriptorsByName[funcMetadata.funcName()] =
outputDescriptorArrayGlobal;
} }
// This must match FuncDescriptor in the runtime. // This must match FuncDescriptor in the runtime.
@ -201,31 +411,23 @@ createFuncDescriptorArray(ArrayRef<refbackrt::FuncMetadataOp> funcMetadatas,
loc, funcDescriptorArrayTy, /*isConstant=*/true, LLVM::Linkage::Internal, loc, funcDescriptorArrayTy, /*isConstant=*/true, LLVM::Linkage::Internal,
"__npcomp_func_descriptors", "__npcomp_func_descriptors",
/*value=*/Attribute()); /*value=*/Attribute());
OpBuilder::InsertionGuard guard(builder); OpBuilder::InsertionGuard guard(builder);
builder.createBlock(&funcDescriptorArrayGlobal.initializer()); builder.createBlock(&funcDescriptorArrayGlobal.initializer());
auto c0 = builder.create<LLVM::ConstantOp>(loc, llvmI32Ty,
builder.getI32IntegerAttr(0));
// Build the initializer. // Build the initializer.
Value funcDescriptorArray = Value funcDescriptorArray =
builder.create<LLVM::UndefOp>(loc, funcDescriptorArrayTy); builder.create<LLVM::UndefOp>(loc, funcDescriptorArrayTy);
auto updateDescriptor = [&](Value value,
std::initializer_list<int32_t> position) {
funcDescriptorArray = builder.create<LLVM::InsertValueOp>(
loc, funcDescriptorArray, value,
/*position=*/builder.getI32ArrayAttr(position));
};
auto updateDescriptorWithI32Attr =
[&](Attribute attr, std::initializer_list<int32_t> position) {
auto constant = builder.create<LLVM::ConstantOp>(loc, llvmI32Ty, attr);
updateDescriptor(constant, position);
};
auto c0 = builder.create<LLVM::ConstantOp>(loc, llvmI32Ty,
builder.getI32IntegerAttr(0));
for (auto funcMetadataAndIndex : llvm::enumerate(funcMetadatas)) { for (auto funcMetadataAndIndex : llvm::enumerate(funcMetadatas)) {
auto funcMetadata = funcMetadataAndIndex.value(); auto funcMetadata = funcMetadataAndIndex.value();
int32_t index = funcMetadataAndIndex.index(); int32_t index = funcMetadataAndIndex.index();
// Name length. // Name length.
updateDescriptorWithI32Attr( updateDescriptorWithI32Attr(
funcDescriptorArray,
builder.getI32IntegerAttr(funcMetadata.funcName().size()), {index, 0}); builder.getI32IntegerAttr(funcMetadata.funcName().size()), {index, 0});
// Name chars. // Name chars.
@ -234,7 +436,7 @@ createFuncDescriptorArray(ArrayRef<refbackrt::FuncMetadataOp> funcMetadatas,
auto funcNamePtr = builder.create<LLVM::GEPOp>( auto funcNamePtr = builder.create<LLVM::GEPOp>(
loc, getInt8PointerType(builder.getContext()), funcNameArray, loc, getInt8PointerType(builder.getContext()), funcNameArray,
ValueRange({c0, c0})); ValueRange({c0, c0}));
updateDescriptor(funcNamePtr, {index, 1}); updateDescriptor(funcDescriptorArray, funcNamePtr, {index, 1});
// Function pointer. // Function pointer.
// //
@ -247,13 +449,31 @@ createFuncDescriptorArray(ArrayRef<refbackrt::FuncMetadataOp> funcMetadatas,
loc, getInt8PointerType(builder.getContext()), funcMetadata.funcName()); loc, getInt8PointerType(builder.getContext()), funcMetadata.funcName());
auto typeErasedFuncAddress = builder.create<LLVM::BitcastOp>( auto typeErasedFuncAddress = builder.create<LLVM::BitcastOp>(
loc, getInt8PointerType(builder.getContext()), funcAddress); loc, getInt8PointerType(builder.getContext()), funcAddress);
updateDescriptor(typeErasedFuncAddress, {index, 2}); updateDescriptor(funcDescriptorArray, typeErasedFuncAddress, {index, 2});
// Number of inputs. // Number of inputs.
updateDescriptorWithI32Attr(funcMetadata.numInputsAttr(), {index, 3}); updateDescriptorWithI32Attr(funcDescriptorArray,
funcMetadata.numInputsAttr(), {index, 3});
// Number of outputs. // Number of outputs.
updateDescriptorWithI32Attr(funcMetadata.numOutputsAttr(), {index, 4}); updateDescriptorWithI32Attr(funcDescriptorArray,
funcMetadata.numOutputsAttr(), {index, 4});
// Input descriptors
auto inputDescriptorsArrayAddress = builder.create<LLVM::AddressOfOp>(
loc, inputDescriptorsByName[funcMetadata.funcName()]);
auto rawInputDescriptorsPtr = builder.create<LLVM::BitcastOp>(
loc, LLVMPointerType::get(getInputDescriptorTy(builder.getContext())),
inputDescriptorsArrayAddress);
updateDescriptor(funcDescriptorArray, rawInputDescriptorsPtr, {index, 5});
// Output descriptors
auto outputDescriptorsArrayAddress = builder.create<LLVM::AddressOfOp>(
loc, outputDescriptorsByName[funcMetadata.funcName()]);
auto rawOutputDescriptorsPtr = builder.create<LLVM::BitcastOp>(
loc, LLVMPointerType::get(getOutputDescriptorTy(builder.getContext())),
outputDescriptorsArrayAddress);
updateDescriptor(funcDescriptorArray, rawOutputDescriptorsPtr, {index, 6});
} }
builder.create<LLVM::ReturnOp>(loc, funcDescriptorArray); builder.create<LLVM::ReturnOp>(loc, funcDescriptorArray);
@ -379,6 +599,16 @@ static Type getUnrankedMemrefDescriptorType(MLIRContext *context) {
/*memorySpace=*/0)); /*memorySpace=*/0));
} }
// LLVM-dialect type corresponding to an f64 value on the ABI boundary.
static Type getDoubleType(MLIRContext *context) {
  auto f64Ty = FloatType::getF64(context);
  return LLVMTypeConverter(context).convertType(f64Ty);
}
// LLVM-dialect type corresponding to an f32 value on the ABI boundary.
static Type getFloatType(MLIRContext *context) {
  auto f32Ty = FloatType::getF32(context);
  return LLVMTypeConverter(context).convertType(f32Ty);
}
// Writes out the logical results of the wrapper function through the void** // Writes out the logical results of the wrapper function through the void**
// passed on the ABI boundary. Because LLVM (and hence llvm.func) // passed on the ABI boundary. Because LLVM (and hence llvm.func)
// only supports a single return type (or void/no results), the logic here needs // only supports a single return type (or void/no results), the logic here needs
@ -393,12 +623,18 @@ static void storeWrapperResults(LLVM::CallOp callToWrapped, Value resultsPtrPtr,
return; return;
Value result = callToWrapped.getResult(0); Value result = callToWrapped.getResult(0);
auto ty = result.getType(); auto ty = result.getType();
// 1 logical result. // 1 logical result.
if (ty == getUnrankedMemrefDescriptorType(ty.getContext())) { if (ty == getUnrankedMemrefDescriptorType(ty.getContext())) {
Value addr = Value addr =
getTypedAddressFromVoidStarStar(resultsPtrPtr, 0, ty, builder, loc); getTypedAddressFromVoidStarStar(resultsPtrPtr, 0, ty, builder, loc);
builder.create<LLVM::StoreOp>(loc, result, addr); builder.create<LLVM::StoreOp>(loc, result, addr);
return; return;
} else if (ty == getFloatType(ty.getContext())) {
Value addr =
getTypedAddressFromVoidStarStar(resultsPtrPtr, 0, ty, builder, loc);
builder.create<LLVM::StoreOp>(loc, result, addr);
return;
} }
assert(ty.isa<LLVMStructType>() && "must be a multi-result packed struct!"); assert(ty.isa<LLVMStructType>() && "must be a multi-result packed struct!");
auto structType = ty.cast<LLVMStructType>(); auto structType = ty.cast<LLVMStructType>();

View File

@ -21,6 +21,16 @@
using namespace mlir; using namespace mlir;
using namespace mlir::NPCOMP; using namespace mlir::NPCOMP;
// Input/output shapes have varying ranks, so to represent them all with a
// single rectangular DenseElementsAttr we pad every per-argument shape out
// to a fixed maximum rank (kMaxRank) at the ABI boundary.
//
// NOTE: When changing this parameter, also change the same `kMaxRank`
// parameter in `lib/RefBackend/LowerToLLVM.cpp` so that the LLVM lowering
// stays consistent.
static constexpr int kMaxRank = 6;
// Get the type used to represent MemRefType `type` on ABI boundaries. // Get the type used to represent MemRefType `type` on ABI boundaries.
// For convenience we do a cast to MemRefType internally. // For convenience we do a cast to MemRefType internally.
static Type getABIMemrefType(Type type) { static Type getABIMemrefType(Type type) {
@ -38,7 +48,99 @@ static bool expressibleWithRefbackrtABI(FunctionType type) {
// Currently, only memref types can be exposed at refbackrt ABI boundaries. // Currently, only memref types can be exposed at refbackrt ABI boundaries.
return llvm::all_of( return llvm::all_of(
llvm::concat<const Type>(type.getInputs(), type.getResults()), llvm::concat<const Type>(type.getInputs(), type.getResults()),
[](Type t) { return t.isa<MemRefType>(); }); [](Type t) {
return t.isa<UnrankedMemRefType, MemRefType, FloatType>();
});
}
// Returns the integer representation of the CompilerDataStructures::ABIType
// for `type`.
// Must stay aligned with CompilerDataStructures::ABIArgType enum.
//
// Returns static_cast<uint32_t>(-1) for types that have no ABI encoding yet
// (e.g. IntegerType), matching the previous fall-through behavior.
static uint32_t getIntReprForABIType(Type type) {
  if (type.isa<MemRefType>() || type.isa<UnrankedMemRefType>()) {
    // ABIArgType::kMemref
    return 1;
  } else if (auto floatTy = type.dyn_cast<FloatType>()) {
    switch (floatTy.getWidth()) {
    case 32:
      // ABIArgType::kF32
      return 2;
    case 64:
      // ABIArgType::kF64
      return 3;
    default:
      assert(false && "Unsupported float bit width");
    }
  }
  // TODO: support IntegerType and other scalar types. Until then, report an
  // explicit invalid marker instead of silently wrapping -1.
  return static_cast<uint32_t>(-1);
}
// Returns the integer representation of the element type of `type`.
// Must stay aligned with CompilerDataStructures::ABIElementType enum.
// Non-shaped types (scalars) have no element type and map to 0 (kNone).
static uint32_t getIntReprForABIElementType(Type type) {
  auto shapedTy = type.dyn_cast<ShapedType>();
  if (!shapedTy)
    return 0;
  auto floatTy = shapedTy.getElementType().dyn_cast<FloatType>();
  if (!floatTy)
    return 0;
  if (floatTy.getWidth() == 32)
    return 1;
  assert(false && "Unsupported tensor element type");
  return 0;
}
// Flattens the shape of `type` into a fixed-length extent vector for the
// ABI boundary:
//  * Dynamic dimensions are encoded as -1.
//  * Trailing unused slots (rank < maxRank) are padded with a 0xDEADBEEF
//    sentinel.
//  * Scalars and unranked shaped types are represented with kMaxRank in
//    every slot.
static SmallVector<int32_t, kMaxRank>
getExtentsForType(Type type, const int32_t maxRank = kMaxRank) {
  // Pad all shapes out to kMaxRank to make our lives easier at the ABI
  // boundary.
  if (auto shapedType = type.dyn_cast<ShapedType>()) {
    if (!shapedType.hasRank()) {
      // Unranked: no usable extents, so fill every slot with kMaxRank.
      return {kMaxRank, kMaxRank, kMaxRank, kMaxRank, kMaxRank, kMaxRank};
    }
    auto shape = shapedType.getShape();
    auto shapeRank = shapedType.getRank();
    if (shapeRank <= maxRank) {
      SmallVector<int32_t, kMaxRank> extendedShape;
      // Push back all the values of the shape (-1 for dynamic dims).
      for (auto extentAndIndex : llvm::enumerate(shape)) {
        auto extent = extentAndIndex.value();
        auto index = extentAndIndex.index();
        if (shapedType.isDynamic(index)) {
          extendedShape.push_back(-1);
        } else {
          extendedShape.push_back(extent);
        }
      }
      // Pad whatever is left (up to maxRank) with a sentinel so all shape
      // vectors have uniform length.
      auto padRank = maxRank - shapeRank;
      for (int i = 0; i < padRank; i++)
        extendedShape.push_back(0xDEAD'BEEF);
      return extendedShape;
    } else {
      // NOTE(review): in release builds this falls through to the sentinel
      // return below rather than failing.
      assert(false && "unsupported rank");
    }
  }
  // Scalars (and the release-mode unsupported-rank fallthrough) are
  // represented with kMaxRank in every slot.
  return {kMaxRank, kMaxRank, kMaxRank, kMaxRank, kMaxRank, kMaxRank};
}
// Returns the rank of `type` for the ABI metadata: -1 for unranked shaped
// types, 0 for non-shaped (scalar) types.
int32_t getRankForType(Type type) {
  auto shapedType = type.dyn_cast<ShapedType>();
  if (!shapedType)
    return 0;
  if (!shapedType.hasRank())
    return -1;
  return shapedType.getRank();
}
// Returns 1 if `type` has a fully static shape, 0 otherwise.
// Scalars and other non-shaped types are treated as static.
uint32_t hasStaticShape(Type type) {
  auto shapedType = type.dyn_cast<ShapedType>();
  if (!shapedType)
    return 1;
  return shapedType.hasStaticShape() ? 1 : 0;
}
static LogicalResult createModuleMetadata(ModuleOp module) { static LogicalResult createModuleMetadata(ModuleOp module) {
@ -59,10 +161,131 @@ static LogicalResult createModuleMetadata(ModuleOp module) {
} }
// TODO: Add richer information here such as expected shapes and element // TODO: Add richer information here such as expected shapes and element
// types. // types.
builder.create<refbackrt::FuncMetadataOp>( SmallVector<uint32_t, 6> inputABIArgTypes;
func.getLoc(), builder.getSymbolRefAttr(func.getName()), SmallVector<uint32_t, 6> inputABIElementTypes;
builder.getI32IntegerAttr(func.getNumArguments()), SmallVector<SmallVector<int32_t, kMaxRank>, 6> inputABIShapes;
builder.getI32IntegerAttr(func.getNumResults())); SmallVector<uint32_t, 6> inputABIRanks;
// SmallVector<uint32_t, 6> inputIsStatic;
for (const auto &inputArgType : func.getBody().front().getArgumentTypes()) {
inputABIArgTypes.push_back(getIntReprForABIType(inputArgType));
inputABIElementTypes.push_back(getIntReprForABIElementType(inputArgType));
inputABIShapes.push_back(
getExtentsForType(inputArgType, /*maxRank=*/kMaxRank));
inputABIRanks.push_back(getRankForType(inputArgType));
// inputIsStatic.push_back(hasStaticShape(inputArgType));
}
SmallVector<uint32_t, 6> outputABIArgTypes;
SmallVector<uint32_t, 6> outputABIElementTypes;
SmallVector<SmallVector<int32_t, kMaxRank>, 6> outputABIShapes;
SmallVector<uint32_t, 6> outputABIRanks;
SmallVector<uint32_t, 6> outputIsStatic;
for (const auto &outputArgType : func.getCallableResults()) {
outputABIArgTypes.push_back(getIntReprForABIType(outputArgType));
outputABIElementTypes.push_back(
getIntReprForABIElementType(outputArgType));
outputABIShapes.push_back(
getExtentsForType(outputArgType, /*maxRank=*/kMaxRank));
outputABIRanks.push_back(getRankForType(outputArgType));
// outputIsStatic.push_back(hasStaticShape(outputArgType));
}
auto i32Type = builder.getIntegerType(32);
auto inputABIDataType =
RankedTensorType::get(inputABIArgTypes.size(), i32Type);
auto inputABIElementType =
RankedTensorType::get(inputABIElementTypes.size(), i32Type);
auto inputABIShapesType = RankedTensorType::get(
llvm::ArrayRef<int64_t>{static_cast<long>(inputABIShapes.size()) *
kMaxRank},
i32Type);
auto inputABIRanksType =
RankedTensorType::get(inputABIRanks.size(), i32Type);
// auto inputIsStaticType = RankedTensorType::get(inputIsStatic.size(),
// i32Type);
auto outputABIDataType =
RankedTensorType::get(outputABIArgTypes.size(), i32Type);
auto outputABIElementType =
RankedTensorType::get(outputABIElementTypes.size(), i32Type);
auto outputABIShapesType = RankedTensorType::get(
llvm::ArrayRef<int64_t>{static_cast<long>(outputABIShapes.size()) *
kMaxRank},
i32Type);
auto outputABIRanksType =
RankedTensorType::get(outputABIRanks.size(), i32Type);
// auto outputIsStaticType = RankedTensorType::get(outputIsStatic.size(),
// i32Type);
// TODO(brycearden): I'm sure there's a cleaner way to do this
auto flattenABIShapes =
[](SmallVector<SmallVector<int32_t, kMaxRank>, 6> shapes) {
SmallVector<int32_t, 32> ret;
for (auto &shape : shapes)
for (auto &dim : shape)
ret.push_back(dim);
return ret;
};
SmallVector<NamedAttribute, 16> namedAttrs;
// Add attributes that are valid for every func (funcName, numInputs,
// numOutputs)
namedAttrs.push_back(
std::make_pair(Identifier::get("funcName", module.getContext()),
builder.getSymbolRefAttr(func.getName())));
namedAttrs.push_back(
std::make_pair(Identifier::get("numInputs", module.getContext()),
builder.getI32IntegerAttr(func.getNumArguments())));
namedAttrs.push_back(
std::make_pair(Identifier::get("numOutputs", module.getContext()),
builder.getI32IntegerAttr(func.getNumResults())));
if (inputABIArgTypes.size()) {
// Only add input information if there are inputs
namedAttrs.push_back(std::make_pair(
Identifier::get("inputArgTypes", func.getContext()),
DenseIntElementsAttr::get(inputABIDataType,
llvm::makeArrayRef(inputABIArgTypes))));
namedAttrs.push_back(std::make_pair(
Identifier::get("inputElementTypes", func.getContext()),
DenseIntElementsAttr::get(inputABIElementType,
llvm::makeArrayRef(inputABIElementTypes))));
namedAttrs.push_back(std::make_pair(
Identifier::get("inputRanks", func.getContext()),
DenseIntElementsAttr::get(inputABIRanksType,
llvm::makeArrayRef(inputABIRanks))));
auto inputShapesFlattened = flattenABIShapes(inputABIShapes);
namedAttrs.push_back(std::make_pair(
Identifier::get("inputShapes", func.getContext()),
DenseIntElementsAttr::get(
inputABIShapesType,
llvm::makeArrayRef(flattenABIShapes(inputABIShapes)))));
}
if (outputABIArgTypes.size()) {
// Only add output information if there are outptus
namedAttrs.push_back(std::make_pair(
Identifier::get("outputArgTypes", func.getContext()),
DenseIntElementsAttr::get(outputABIDataType,
llvm::makeArrayRef(outputABIArgTypes))));
namedAttrs.push_back(std::make_pair(
Identifier::get("outputElementTypes", func.getContext()),
DenseIntElementsAttr::get(
outputABIElementType,
llvm::makeArrayRef(outputABIElementTypes))));
namedAttrs.push_back(std::make_pair(
Identifier::get("outputRanks", func.getContext()),
DenseIntElementsAttr::get(outputABIRanksType,
llvm::makeArrayRef(outputABIRanks))));
namedAttrs.push_back(std::make_pair(
Identifier::get("outputShapes", func.getContext()),
DenseIntElementsAttr::get(
outputABIShapesType,
llvm::makeArrayRef(flattenABIShapes(outputABIShapes)))));
}
builder.create<refbackrt::FuncMetadataOp>(func.getLoc(), ArrayRef<Type>{},
ArrayRef<Value>{}, namedAttrs);
if (!expressibleWithRefbackrtABI(func.getType())) if (!expressibleWithRefbackrtABI(func.getType()))
return func.emitError() << "func not expressible with refbackrt ABI"; return func.emitError() << "func not expressible with refbackrt ABI";

View File

@ -23,6 +23,40 @@ namespace refbackrt {
// LowerToLLVM.cpp for more details. // LowerToLLVM.cpp for more details.
typedef void ABIFunc(void **, void **); typedef void ABIFunc(void **, void **);
// Tag describing how a single argument is passed across the ABI boundary.
// The numeric values are baked into compiled modules by the compiler, so
// they must stay aligned with the integer encoding emitted in
// `LowerToRefbackrtABI.cpp`.
enum class ABIArgType : std::uint32_t {
  kNone = 0,
  kMemref,
  kF32,
  kF64,
};
// Element type of a container-like argument (e.g. the f32 in a f32 tensor).
// kNone is used for scalars and anything without an element type. Must stay
// aligned with the integer encoding emitted in `LowerToRefbackrtABI.cpp`.
enum class ABIElementType : std::uint32_t {
  kNone = 0,
  kF32,
};
// Compiler-emitted description of one function input, used for runtime
// type/shape checking.
struct InputDescriptor {
  // How the argument is passed at the ABI boundary.
  ABIArgType abiType;
  // Element type for container-like arguments; kNone for scalars.
  ABIElementType elementType;
  // Rank of the argument; -1 is presumably "unranked" — confirm against the
  // compiler-side getRankForType.
  std::int32_t rank;
  // Pointer to the extents array; only the first `rank` entries are read.
  std::int32_t* extents;
  // TODO(brycearden): Change to bool at ABI boundary
  // std::int32_t isStatic;
};
// Compiler-emitted description of one function output, used to allocate
// placeholder results before invoking the function.
struct OutputDescriptor {
  // How the result is returned at the ABI boundary.
  ABIArgType abiType;
  // Element type for container-like results; kNone for scalars.
  ABIElementType elementType;
  // Rank of the result; -1 is presumably "unranked" — confirm against the
  // compiler-side getRankForType.
  std::int32_t rank;
  // Pointer to the extents array; only the first `rank` entries are read.
  std::int32_t* extents;
  // TODO(brycearden): Change to bool at ABI boundary
  //std::int32_t isStatic;
};
struct FuncDescriptor { struct FuncDescriptor {
// The length of the function name. // The length of the function name.
std::int32_t nameLen; std::int32_t nameLen;
@ -35,9 +69,9 @@ struct FuncDescriptor {
std::int32_t numInputs; std::int32_t numInputs;
// The number of outputs of the function. // The number of outputs of the function.
std::int32_t numOutputs; std::int32_t numOutputs;
// TODO: Add arg/result descriptors and other metadata. // TODO: Add shape checking to arg / result descriptor(s)
// With those descriptors we can do type and shape checking for each InputDescriptor *inputDescriptors;
// argument. OutputDescriptor *outputDescriptors;
}; };
// The top-level entry point of the module metadata emitted by the // The top-level entry point of the module metadata emitted by the

View File

@ -118,12 +118,38 @@ static std::int32_t totalElements(ArrayRef<std::int32_t> extents) {
std::int32_t refbackrt::getElementTypeByteSize(ElementType type) { std::int32_t refbackrt::getElementTypeByteSize(ElementType type) {
switch (type) { switch (type) {
case ElementType::NONE:
return 0;
case ElementType::F32: case ElementType::F32:
return 4; return 4;
} }
llvm_unreachable("unsupported dtype"); llvm_unreachable("unsupported dtype");
} }
// Human-readable name for an ElementType, used in diagnostics.
StringRef refbackrt::getElementTypeAsStringRef(ElementType type) {
  if (type == ElementType::NONE)
    return "NONE";
  if (type == ElementType::F32)
    return "F32";
  llvm_unreachable("unsupported element type string");
}
// Human-readable name for an ArgType, used in diagnostics.
StringRef refbackrt::getArgTypeAsStringRef(ArgType type) {
  if (type == ArgType::kNone)
    return "kNone";
  if (type == ArgType::kTensor)
    return "kTensor";
  if (type == ArgType::kF32)
    return "kF32";
  if (type == ArgType::kF64)
    return "kF64";
  llvm_unreachable("unsupported arg type string");
}
Ref<Tensor> Tensor::create(ArrayRef<std::int32_t> extents, ElementType type, Ref<Tensor> Tensor::create(ArrayRef<std::int32_t> extents, ElementType type,
void *data) { void *data) {
return Ref<Tensor>(createRaw(extents, type, data)); return Ref<Tensor>(createRaw(extents, type, data));
@ -192,10 +218,7 @@ void refbackrt::invoke(ModuleDescriptor *moduleDescriptor,
// Deepcopy the refbackrt::Tensor's into UnrankedMemref's. // Deepcopy the refbackrt::Tensor's into UnrankedMemref's.
// TODO: Avoid the deep copy. It makes the later lifetime management code // TODO: Avoid the deep copy. It makes the later lifetime management code
// more complex though (and maybe impossible given the current abstractions). // more complex though (and maybe impossible given the current abstractions).
for (int i = 0, e = inputs.size(); i < e; i++) { //
inputUnrankedMemrefs[i] =
convertRefbackrtTensorToUnrankedMemref(inputs[i].toTensor().get());
}
// Create a type-erased list of "packed inputs" to pass to the // Create a type-erased list of "packed inputs" to pass to the
// LLVM/C ABI wrapper function. Each packedInput pointer corresponds to // LLVM/C ABI wrapper function. Each packedInput pointer corresponds to
// one LLVM/C ABI argument to the underlying function. // one LLVM/C ABI argument to the underlying function.
@ -204,16 +227,30 @@ void refbackrt::invoke(ModuleDescriptor *moduleDescriptor,
// "explode" the unranked memref descriptors on the underlying function // "explode" the unranked memref descriptors on the underlying function
// into separate arguments for the rank and pointer-to-descriptor. // into separate arguments for the rank and pointer-to-descriptor.
for (int i = 0, e = inputs.size(); i < e; i++) { for (int i = 0, e = inputs.size(); i < e; i++) {
packedInputs[2 * i] = ToVoidPtr(&inputUnrankedMemrefs[i].rank); auto idx = 2 * i;
packedInputs[2 * i + 1] = ToVoidPtr(&inputUnrankedMemrefs[i].descriptor); if (inputs[i].isTensor()) {
inputUnrankedMemrefs[i] =
convertRefbackrtTensorToUnrankedMemref(inputs[i].toTensor().get());
packedInputs[idx] = ToVoidPtr(&inputUnrankedMemrefs[i].rank);
packedInputs[idx + 1] = ToVoidPtr(&inputUnrankedMemrefs[i].descriptor);
} else if (inputs[i].isScalar()) {
packedInputs[idx] = ToVoidPtr(&inputs[i]);
} else {
assert(false && "unsupported input RtValue type");
}
} }
// Create a type-erased list of "packed output" to pass to the // Create a type-erased list of "packed output" to pass to the
// LLVM/C ABI wrapper function. // LLVM/C ABI wrapper function.
// //
// Due to how StandardToLLVM lowering works, each packedOutput pointer // Due to how StandardToLLVM lowering works, each packedOutput pointer
// corresponds to a single UnrankedMemref (not "exploded"). // corresponds to a single UnrankedMemref (not "exploded").
for (int i = 0, e = outputs.size(); i < e; i++) { for (int i = 0, e = outputs.size(); i < e; i++) {
packedOutputs[i] = ToVoidPtr(&outputUnrankedMemrefs[i]); if (outputs[i].isTensor()) {
packedOutputs[i] = ToVoidPtr(&outputUnrankedMemrefs[i]);
} else if (outputs[i].isScalar()) {
packedOutputs[i] = ToVoidPtr(&outputs[i]);
}
} }
// Actually invoke the function! // Actually invoke the function!
@ -223,11 +260,15 @@ void refbackrt::invoke(ModuleDescriptor *moduleDescriptor,
// TODO: Avoid needing to make a deep copy. // TODO: Avoid needing to make a deep copy.
for (int i = 0, e = outputs.size(); i < e; i++) { for (int i = 0, e = outputs.size(); i < e; i++) {
// TODO: Have compiler emit the element type in the metadata. // TODO: Have compiler emit the element type in the metadata.
auto elementType = ElementType::F32; if (outputs[i].isTensor()) {
Tensor *tensor = convertUnrankedMemrefToRefbackrtTensor( auto elementType = ElementType::F32;
outputUnrankedMemrefs[i].rank, outputUnrankedMemrefs[i].descriptor, Tensor *tensor = convertUnrankedMemrefToRefbackrtTensor(
elementType); outputUnrankedMemrefs[i].rank, outputUnrankedMemrefs[i].descriptor,
outputs[i] = RtValue(Ref<Tensor>(tensor)); elementType);
outputs[i] = RtValue(Ref<Tensor>(tensor));
} else if (outputs[i].isFloat()) {
outputs[i] = RtValue(*(reinterpret_cast<float *>(packedOutputs[i])));
}
} }
// Now, we just need to free all the UnrankedMemref's that we created. // Now, we just need to free all the UnrankedMemref's that we created.
@ -239,24 +280,30 @@ void refbackrt::invoke(ModuleDescriptor *moduleDescriptor,
// Free the output buffers. // Free the output buffers.
for (int i = 0, e = outputs.size(); i < e; i++) { for (int i = 0, e = outputs.size(); i < e; i++) {
void *allocatedPtr = outputUnrankedMemrefs[i].descriptor->allocatedPtr; if (outputs[i].isRef()) {
// Multiple returned memrefs can point into the same underlying void *allocatedPtr = outputUnrankedMemrefs[i].descriptor->allocatedPtr;
// malloc allocation. Do a linear scan to see if any of the previously // Multiple returned memrefs can point into the same underlying
// deallocated buffers already freed this pointer. // malloc allocation. Do a linear scan to see if any of the previously
bool bufferNeedsFreeing = true; // deallocated buffers already freed this pointer.
for (int j = 0; j < i; j++) { bool bufferNeedsFreeing = true;
if (allocatedPtr == outputUnrankedMemrefs[j].descriptor->allocatedPtr) for (int j = 0; j < i; j++) {
bufferNeedsFreeing = false; if (allocatedPtr == outputUnrankedMemrefs[j].descriptor->allocatedPtr)
bufferNeedsFreeing = false;
}
if (!bufferNeedsFreeing)
std::free(allocatedPtr);
} }
if (!bufferNeedsFreeing)
std::free(allocatedPtr);
} }
// Free the input buffers. // Free the input buffers.
for (int i = 0, e = inputs.size(); i < e; i++) { for (int i = 0, e = inputs.size(); i < e; i++) {
if (!inputs[i].isRef())
continue;
void *allocatedPtr = inputUnrankedMemrefs[i].descriptor->allocatedPtr; void *allocatedPtr = inputUnrankedMemrefs[i].descriptor->allocatedPtr;
bool bufferNeedsFreeing = true; bool bufferNeedsFreeing = true;
for (int j = 0, je = outputs.size(); j < je; j++) { for (int j = 0, je = outputs.size(); j < je; j++) {
if (!outputs[j].isRef())
continue;
if (allocatedPtr == outputUnrankedMemrefs[j].descriptor->allocatedPtr) if (allocatedPtr == outputUnrankedMemrefs[j].descriptor->allocatedPtr)
bufferNeedsFreeing = false; bufferNeedsFreeing = false;
} }
@ -274,6 +321,8 @@ void refbackrt::invoke(ModuleDescriptor *moduleDescriptor,
// Free the output descriptors. // Free the output descriptors.
for (int i = 0, e = outputs.size(); i < e; i++) { for (int i = 0, e = outputs.size(); i < e; i++) {
if (!outputs[i].isRef())
continue;
// The LLVM lowering guarantees that each returned unranked memref // The LLVM lowering guarantees that each returned unranked memref
// descriptor is separately malloc'ed, so no need to do anything special // descriptor is separately malloc'ed, so no need to do anything special
// like we had to do for the allocatedPtr's. // like we had to do for the allocatedPtr's.
@ -281,10 +330,81 @@ void refbackrt::invoke(ModuleDescriptor *moduleDescriptor,
} }
// Free the input descriptors. // Free the input descriptors.
for (int i = 0, e = inputs.size(); i < e; i++) { for (int i = 0, e = inputs.size(); i < e; i++) {
if (!inputs[i].isRef())
continue;
std::free(inputUnrankedMemrefs[i].descriptor); std::free(inputUnrankedMemrefs[i].descriptor);
} }
} }
// Translates a compiler-emitted InputDescriptor into the external
// InputArgInfo representation used by the runtime API.
static InputArgInfo
getExternalInputArgInfo(const refbackrt::InputDescriptor &inputDescriptor) {
  InputArgInfo ret;

  // Map the ABI-level argument type onto the (ArgType, ElementType) pair.
  ArgType argType = ArgType::kNone;
  ElementType elementType = ElementType::NONE;
  switch (inputDescriptor.abiType) {
  case ABIArgType::kNone:
    break;
  case ABIArgType::kMemref:
    argType = ArgType::kTensor;
    // NOTE(review): element type is hardcoded to F32 rather than derived
    // from inputDescriptor.elementType; only f32 tensors are supported.
    elementType = ElementType::F32;
    break;
  case ABIArgType::kF32:
    argType = ArgType::kF32;
    break;
  case ABIArgType::kF64:
    argType = ArgType::kF64;
    break;
  default:
    assert(false && "need to update external internal map");
  }
  ret.argType = argType;
  ret.elementType = elementType;

  // Copy over the rank and the valid extents.
  ret.rank = inputDescriptor.rank;
  for (int j = 0; j < inputDescriptor.rank; j++)
    ret.extents[j] = inputDescriptor.extents[j];

  return ret;
}
// Translates a compiler-emitted OutputDescriptor into the external
// OutputArgInfo representation used by the runtime API.
static OutputArgInfo
getExternalOutputArgInfo(const refbackrt::OutputDescriptor &outputDescriptor) {
  OutputArgInfo ret;

  // Map the ABI-level result type onto the (ArgType, ElementType) pair.
  ArgType argType = ArgType::kNone;
  ElementType elementType = ElementType::NONE;
  switch (outputDescriptor.abiType) {
  case ABIArgType::kNone:
    break;
  case ABIArgType::kMemref:
    argType = ArgType::kTensor;
    // NOTE(review): element type is hardcoded to F32 rather than derived
    // from outputDescriptor.elementType; only f32 tensors are supported.
    elementType = ElementType::F32;
    break;
  case ABIArgType::kF32:
    argType = ArgType::kF32;
    break;
  case ABIArgType::kF64:
    argType = ArgType::kF64;
    break;
  default:
    assert(false && "need to update external internal map");
  }
  ret.argType = argType;
  ret.elementType = elementType;

  // Copy over the rank and the valid extents.
  ret.rank = outputDescriptor.rank;
  for (int j = 0; j < outputDescriptor.rank; j++)
    ret.extents[j] = outputDescriptor.extents[j];

  return ret;
}
LogicalResult refbackrt::getMetadata(ModuleDescriptor *moduleDescriptor, LogicalResult refbackrt::getMetadata(ModuleDescriptor *moduleDescriptor,
StringRef functionName, StringRef functionName,
FunctionMetadata &outMetadata) { FunctionMetadata &outMetadata) {
@ -293,5 +413,107 @@ LogicalResult refbackrt::getMetadata(ModuleDescriptor *moduleDescriptor,
return failure(); return failure();
outMetadata.numInputs = descriptor->numInputs; outMetadata.numInputs = descriptor->numInputs;
outMetadata.numOutputs = descriptor->numOutputs; outMetadata.numOutputs = descriptor->numOutputs;
for (int i = 0; i < descriptor->numInputs; i++) {
outMetadata.inputArgInfos[i] =
getExternalInputArgInfo(descriptor->inputDescriptors[i]);
}
for (int i = 0; i < descriptor->numOutputs; i++) {
outMetadata.outputArgInfos[i] =
getExternalOutputArgInfo(descriptor->outputDescriptors[i]);
}
return success(); return success();
} }
// Verifies that a tensor RtValue's rank and extents match the compiled
// function's expectations. Non-tensor values vacuously pass; unranked
// expected types and dynamic dimensions are skipped.
LogicalResult refbackrt::checkRtValueShapes(const RtValue &value,
                                            const InputArgInfo &info) {
  if (!value.isTensor())
    return success();

  auto refTensor = value.toTensor();
  // Unranked expected type: nothing to check.
  if (info.rank < 0)
    return success();
  if (refTensor->getRank() != info.rank)
    return failure();

  auto tensorExtents = refTensor->getExtents();
  for (int dim = 0; dim < info.rank; dim++) {
    // Dynamic dimensions are encoded as extent = -1; skip those.
    if (info.extents[dim] > 0 && info.extents[dim] != tensorExtents[dim])
      return failure();
  }
  return success();
}
// Verifies that an input RtValue's type matches the compiled function's
// expectations. Currently only f32 scalars and f32 tensors are supported.
LogicalResult refbackrt::checkRtValueArgTypes(const RtValue &value,
                                              const InputArgInfo &info) {
  // Coarse argType checks first.
  if (value.isTensor() && info.argType != ArgType::kTensor)
    return failure();
  if (value.isFloat() && info.argType != ArgType::kF32)
    return failure();

  // Non-ref-counted values need no further checking.
  if (!value.isRef())
    return success();

  // Ref-counted values get deeper checks; only f32 tensors are supported.
  if (value.isTensor()) {
    auto refTensor = value.toTensor();
    if (refTensor->getElementType() != ElementType::F32)
      return failure();
    return success();
  }
  assert(false && "Unsupported input type checking for Ref type");
  return success();
}
// Creates a placeholder RtValue for an output slot described by `info`.
// Tensor outputs are zero-initialized; scalar outputs get a sentinel value.
// The invoked function overwrites these placeholders with the real results.
RtValue refbackrt::createRtValueFromOutputArgInfo(const OutputArgInfo &info) {
  constexpr int32_t kDynamicConstantShape = 100;
  switch (info.argType) {
  case ArgType::kTensor: {
    // HACK: for dynamic dims the shape will be negative, so for now we are
    // just going to create a tensor of size kDynamicConstantShape
    std::array<int32_t, kMaxRank> tensorShape;
    for (int i = 0; i < info.rank; i++) {
      tensorShape[i] =
          info.extents[i] > 0 ? info.extents[i] : kDynamicConstantShape;
    }
    refbackrt::ArrayRef<int32_t> shape(tensorShape.data(), info.rank);
    int numel = 1;
    for (int i = 0; i < info.rank; i++)
      numel *= shape[i];

    switch (info.elementType) {
    case ElementType::F32: {
      constexpr size_t kAlignment = 32;
      size_t byteSize = static_cast<size_t>(numel) * sizeof(float);
      // aligned_alloc requires the allocation size to be a multiple of the
      // alignment, so round up.
      size_t allocSize = (byteSize + kAlignment - 1) / kAlignment * kAlignment;
      void *data = aligned_alloc(kAlignment, allocSize);
      assert(data && "data ptr must exist");
      memset(data, 0, byteSize);
      // Tensor::create deep-copies `data` into the Tensor object, so the
      // temporary buffer must be freed here. (The previous code returned
      // without freeing it, leaking the buffer on every call.)
      auto refTensor = Tensor::create(shape, ElementType::F32, data);
      free(data);
      return RtValue(refTensor);
    }
    default:
      assert(false && "unknown output tensor type");
      return RtValue();
    }
  }
  case ArgType::kF32: {
    // Sentinel scalar placeholder; overwritten by the invoked function.
    return RtValue(-20.0f);
  }
  default: {
    assert(false && "Don't know how to handle this argType");
    return RtValue();
  }
  }
}

View File

@ -3,9 +3,17 @@
// CHECK: refbackrt.module_metadata // CHECK: refbackrt.module_metadata
refbackrt.module_metadata { refbackrt.module_metadata {
// CHECK: refbackrt.func_metadata // CHECK: refbackrt.func_metadata
refbackrt.func_metadata {funcName = @f, numInputs = 1 : i32, numOutputs = 0 : i32} // TODO(brycearden): Encode unranked memrefs in the ABI
refbackrt.func_metadata {
funcName = @f,
numInputs = 1 : i32,
numOutputs = 0 : i32,
inputArgTypes = dense<1> : tensor<1xi32>,
inputElementTypes = dense<1> : tensor<1xi32>,
inputRanks = dense<-1> : tensor<1xi32>,
inputShapes = dense<1> : tensor<4xi32>}
} }
func @f(%arg0: memref<*xf32>) { func @f(%arg0: tensor<*xf32>) {
return return
} }

View File

@ -1,165 +0,0 @@
// RUN: npcomp-opt -refback-lower-to-llvm -split-input-file <%s | FileCheck %s --dump-input=fail
// Test input/output arg marshaling.
// CHECK-LABEL: llvm.func @__refbackrt_wrapper_inputs1results2(
// CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr<ptr<i8>>,
// CHECK-SAME: %[[VAL_1:.*]]: !llvm.ptr<ptr<i8>>) {
// CHECK: %[[VAL_2:.*]] = llvm.mlir.constant(0 : i32) : i32
// CHECK: %[[VAL_3:.*]] = llvm.getelementptr %[[VAL_0]]{{\[}}%[[VAL_2]]] : (!llvm.ptr<ptr<i8>>, i32) -> !llvm.ptr<ptr<i8>>
// CHECK: %[[VAL_4:.*]] = llvm.load %[[VAL_3]] : !llvm.ptr<ptr<i8>>
// CHECK: %[[VAL_5:.*]] = llvm.bitcast %[[VAL_4]] : !llvm.ptr<i8> to !llvm.ptr<i64>
// CHECK: %[[VAL_6:.*]] = llvm.load %[[VAL_5]] : !llvm.ptr<i64>
// CHECK: %[[VAL_7:.*]] = llvm.mlir.constant(1 : i32) : i32
// CHECK: %[[VAL_8:.*]] = llvm.getelementptr %[[VAL_0]]{{\[}}%[[VAL_7]]] : (!llvm.ptr<ptr<i8>>, i32) -> !llvm.ptr<ptr<i8>>
// CHECK: %[[VAL_9:.*]] = llvm.load %[[VAL_8]] : !llvm.ptr<ptr<i8>>
// CHECK: %[[VAL_10:.*]] = llvm.bitcast %[[VAL_9]] : !llvm.ptr<i8> to !llvm.ptr<ptr<i8>>
// CHECK: %[[VAL_11:.*]] = llvm.load %[[VAL_10]] : !llvm.ptr<ptr<i8>>
// CHECK: %[[VAL_12:.*]] = llvm.call @inputs1results2(%[[VAL_6]], %[[VAL_11]]) : (i64, !llvm.ptr<i8>) -> !llvm.struct<(struct<(i64, ptr<i8>)>, struct<(i64, ptr<i8>)>)>
// CHECK: %[[VAL_13:.*]] = llvm.mlir.constant(0 : i32) : i32
// CHECK: %[[VAL_14:.*]] = llvm.getelementptr %[[VAL_1]]{{\[}}%[[VAL_13]]] : (!llvm.ptr<ptr<i8>>, i32) -> !llvm.ptr<ptr<i8>>
// CHECK: %[[VAL_15:.*]] = llvm.load %[[VAL_14]] : !llvm.ptr<ptr<i8>>
// CHECK: %[[VAL_16:.*]] = llvm.bitcast %[[VAL_15]] : !llvm.ptr<i8> to !llvm.ptr<struct<(i64, ptr<i8>)>>
// CHECK: %[[VAL_17:.*]] = llvm.extractvalue %[[VAL_12]][0 : i32] : !llvm.struct<(struct<(i64, ptr<i8>)>, struct<(i64, ptr<i8>)>)>
// CHECK: llvm.store %[[VAL_17]], %[[VAL_16]] : !llvm.ptr<struct<(i64, ptr<i8>)>>
// CHECK: %[[VAL_18:.*]] = llvm.mlir.constant(1 : i32) : i32
// CHECK: %[[VAL_19:.*]] = llvm.getelementptr %[[VAL_1]]{{\[}}%[[VAL_18]]] : (!llvm.ptr<ptr<i8>>, i32) -> !llvm.ptr<ptr<i8>>
// CHECK: %[[VAL_20:.*]] = llvm.load %[[VAL_19]] : !llvm.ptr<ptr<i8>>
// CHECK: %[[VAL_21:.*]] = llvm.bitcast %[[VAL_20]] : !llvm.ptr<i8> to !llvm.ptr<struct<(i64, ptr<i8>)>>
// CHECK: %[[VAL_22:.*]] = llvm.extractvalue %[[VAL_12]][1 : i32] : !llvm.struct<(struct<(i64, ptr<i8>)>, struct<(i64, ptr<i8>)>)>
// CHECK: llvm.store %[[VAL_22]], %[[VAL_21]] : !llvm.ptr<struct<(i64, ptr<i8>)>>
// CHECK: llvm.return
// CHECK: }
// CHECK-LABEL: llvm.func @__refbackrt_wrapper_inputs1results1(
// CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr<ptr<i8>>,
// CHECK-SAME: %[[VAL_1:.*]]: !llvm.ptr<ptr<i8>>) {
// CHECK: %[[VAL_2:.*]] = llvm.mlir.constant(0 : i32) : i32
// CHECK: %[[VAL_3:.*]] = llvm.getelementptr %[[VAL_0]]{{\[}}%[[VAL_2]]] : (!llvm.ptr<ptr<i8>>, i32) -> !llvm.ptr<ptr<i8>>
// CHECK: %[[VAL_4:.*]] = llvm.load %[[VAL_3]] : !llvm.ptr<ptr<i8>>
// CHECK: %[[VAL_5:.*]] = llvm.bitcast %[[VAL_4]] : !llvm.ptr<i8> to !llvm.ptr<i64>
// CHECK: %[[VAL_6:.*]] = llvm.load %[[VAL_5]] : !llvm.ptr<i64>
// CHECK: %[[VAL_7:.*]] = llvm.mlir.constant(1 : i32) : i32
// CHECK: %[[VAL_8:.*]] = llvm.getelementptr %[[VAL_0]]{{\[}}%[[VAL_7]]] : (!llvm.ptr<ptr<i8>>, i32) -> !llvm.ptr<ptr<i8>>
// CHECK: %[[VAL_9:.*]] = llvm.load %[[VAL_8]] : !llvm.ptr<ptr<i8>>
// CHECK: %[[VAL_10:.*]] = llvm.bitcast %[[VAL_9]] : !llvm.ptr<i8> to !llvm.ptr<ptr<i8>>
// CHECK: %[[VAL_11:.*]] = llvm.load %[[VAL_10]] : !llvm.ptr<ptr<i8>>
// CHECK: %[[VAL_12:.*]] = llvm.call @inputs1results1(%[[VAL_6]], %[[VAL_11]]) : (i64, !llvm.ptr<i8>) -> !llvm.struct<(i64, ptr<i8>)>
// CHECK: %[[VAL_13:.*]] = llvm.mlir.constant(0 : i32) : i32
// CHECK: %[[VAL_14:.*]] = llvm.getelementptr %[[VAL_1]]{{\[}}%[[VAL_13]]] : (!llvm.ptr<ptr<i8>>, i32) -> !llvm.ptr<ptr<i8>>
// CHECK: %[[VAL_15:.*]] = llvm.load %[[VAL_14]] : !llvm.ptr<ptr<i8>>
// CHECK: %[[VAL_16:.*]] = llvm.bitcast %[[VAL_15]] : !llvm.ptr<i8> to !llvm.ptr<struct<(i64, ptr<i8>)>>
// CHECK: llvm.store %[[VAL_12]], %[[VAL_16]] : !llvm.ptr<struct<(i64, ptr<i8>)>>
// CHECK: llvm.return
// CHECK: }
/// CHECK-LABEL: llvm.func @__refbackrt_wrapper_inputs1results0(
// CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr<ptr<i8>>,
// CHECK-SAME: %[[VAL_1:.*]]: !llvm.ptr<ptr<i8>>) {
// CHECK: %[[VAL_2:.*]] = llvm.mlir.constant(0 : i32) : i32
// CHECK: %[[VAL_3:.*]] = llvm.getelementptr %[[VAL_0]]{{\[}}%[[VAL_2]]] : (!llvm.ptr<ptr<i8>>, i32) -> !llvm.ptr<ptr<i8>>
// CHECK: %[[VAL_4:.*]] = llvm.load %[[VAL_3]] : !llvm.ptr<ptr<i8>>
// CHECK: %[[VAL_5:.*]] = llvm.bitcast %[[VAL_4]] : !llvm.ptr<i8> to !llvm.ptr<i64>
// CHECK: %[[VAL_6:.*]] = llvm.load %[[VAL_5]] : !llvm.ptr<i64>
// CHECK: %[[VAL_7:.*]] = llvm.mlir.constant(1 : i32) : i32
// CHECK: %[[VAL_8:.*]] = llvm.getelementptr %[[VAL_0]]{{\[}}%[[VAL_7]]] : (!llvm.ptr<ptr<i8>>, i32) -> !llvm.ptr<ptr<i8>>
// CHECK: %[[VAL_9:.*]] = llvm.load %[[VAL_8]] : !llvm.ptr<ptr<i8>>
// CHECK: %[[VAL_10:.*]] = llvm.bitcast %[[VAL_9]] : !llvm.ptr<i8> to !llvm.ptr<ptr<i8>>
// CHECK: %[[VAL_11:.*]] = llvm.load %[[VAL_10]] : !llvm.ptr<ptr<i8>>
// CHECK: llvm.call @inputs1results0(%[[VAL_6]], %[[VAL_11]]) : (i64, !llvm.ptr<i8>) -> ()
// CHECK: llvm.return
// CHECK: }
// CHECK: llvm.func @__npcomp_compiler_rt_abort_if(i1, !llvm.ptr<i8>)
// CHECK: llvm.mlir.global internal constant @__npcomp_internal_constant_inputs1results0("inputs1results0")
// CHECK: llvm.mlir.global internal constant @__npcomp_internal_constant_inputs1results1("inputs1results1")
// CHECK: llvm.mlir.global internal constant @__npcomp_internal_constant_inputs1results2("inputs1results2")
// CHECK-LABEL: llvm.mlir.global internal constant @__npcomp_func_descriptors() : !llvm.array<3 x struct<(i32, ptr<i8>, ptr<i8>, i32, i32)>> {
// CHECK: %[[VAL_0:.*]] = llvm.mlir.undef : !llvm.array<3 x struct<(i32, ptr<i8>, ptr<i8>, i32, i32)>>
// CHECK: %[[VAL_1:.*]] = llvm.mlir.constant(0 : i32) : i32
// CHECK: %[[VAL_2:.*]] = llvm.mlir.constant(15 : i32) : i32
// CHECK: %[[VAL_3:.*]] = llvm.insertvalue %[[VAL_2]], %[[VAL_0]][0 : i32, 0 : i32] : !llvm.array<3 x struct<(i32, ptr<i8>, ptr<i8>, i32, i32)>>
// CHECK: %[[VAL_4:.*]] = llvm.mlir.addressof @__npcomp_internal_constant_inputs1results0 : !llvm.ptr<array<15 x i8>>
// CHECK: %[[VAL_5:.*]] = llvm.getelementptr %[[VAL_4]]{{\[}}%[[VAL_1]], %[[VAL_1]]] : (!llvm.ptr<array<15 x i8>>, i32, i32) -> !llvm.ptr<i8>
// CHECK: %[[VAL_6:.*]] = llvm.insertvalue %[[VAL_5]], %[[VAL_3]][0 : i32, 1 : i32] : !llvm.array<3 x struct<(i32, ptr<i8>, ptr<i8>, i32, i32)>>
// CHECK: %[[VAL_7:.*]] = llvm.mlir.addressof @__refbackrt_wrapper_inputs1results0 : !llvm.ptr<func<void (ptr<ptr<i8>>, ptr<ptr<i8>>)>>
// CHECK: %[[VAL_8:.*]] = llvm.bitcast %[[VAL_7]] : !llvm.ptr<func<void (ptr<ptr<i8>>, ptr<ptr<i8>>)>> to !llvm.ptr<i8>
// CHECK: %[[VAL_9:.*]] = llvm.insertvalue %[[VAL_8]], %[[VAL_6]][0 : i32, 2 : i32] : !llvm.array<3 x struct<(i32, ptr<i8>, ptr<i8>, i32, i32)>>
// CHECK: %[[VAL_10:.*]] = llvm.mlir.constant(1 : i32) : i32
// CHECK: %[[VAL_11:.*]] = llvm.insertvalue %[[VAL_10]], %[[VAL_9]][0 : i32, 3 : i32] : !llvm.array<3 x struct<(i32, ptr<i8>, ptr<i8>, i32, i32)>>
// CHECK: %[[VAL_12:.*]] = llvm.mlir.constant(0 : i32) : i32
// CHECK: %[[VAL_13:.*]] = llvm.insertvalue %[[VAL_12]], %[[VAL_11]][0 : i32, 4 : i32] : !llvm.array<3 x struct<(i32, ptr<i8>, ptr<i8>, i32, i32)>>
// CHECK: %[[VAL_14:.*]] = llvm.mlir.constant(15 : i32) : i32
// CHECK: %[[VAL_15:.*]] = llvm.insertvalue %[[VAL_14]], %[[VAL_13]][1 : i32, 0 : i32] : !llvm.array<3 x struct<(i32, ptr<i8>, ptr<i8>, i32, i32)>>
// CHECK: %[[VAL_16:.*]] = llvm.mlir.addressof @__npcomp_internal_constant_inputs1results1 : !llvm.ptr<array<15 x i8>>
// CHECK: %[[VAL_17:.*]] = llvm.getelementptr %[[VAL_16]]{{\[}}%[[VAL_1]], %[[VAL_1]]] : (!llvm.ptr<array<15 x i8>>, i32, i32) -> !llvm.ptr<i8>
// CHECK: %[[VAL_18:.*]] = llvm.insertvalue %[[VAL_17]], %[[VAL_15]][1 : i32, 1 : i32] : !llvm.array<3 x struct<(i32, ptr<i8>, ptr<i8>, i32, i32)>>
// CHECK: %[[VAL_19:.*]] = llvm.mlir.addressof @__refbackrt_wrapper_inputs1results1 : !llvm.ptr<func<void (ptr<ptr<i8>>, ptr<ptr<i8>>)>>
// CHECK: %[[VAL_20:.*]] = llvm.bitcast %[[VAL_19]] : !llvm.ptr<func<void (ptr<ptr<i8>>, ptr<ptr<i8>>)>> to !llvm.ptr<i8>
// CHECK: %[[VAL_21:.*]] = llvm.insertvalue %[[VAL_20]], %[[VAL_18]][1 : i32, 2 : i32] : !llvm.array<3 x struct<(i32, ptr<i8>, ptr<i8>, i32, i32)>>
// CHECK: %[[VAL_22:.*]] = llvm.mlir.constant(1 : i32) : i32
// CHECK: %[[VAL_23:.*]] = llvm.insertvalue %[[VAL_22]], %[[VAL_21]][1 : i32, 3 : i32] : !llvm.array<3 x struct<(i32, ptr<i8>, ptr<i8>, i32, i32)>>
// CHECK: %[[VAL_24:.*]] = llvm.mlir.constant(1 : i32) : i32
// CHECK: %[[VAL_25:.*]] = llvm.insertvalue %[[VAL_24]], %[[VAL_23]][1 : i32, 4 : i32] : !llvm.array<3 x struct<(i32, ptr<i8>, ptr<i8>, i32, i32)>>
// CHECK: %[[VAL_26:.*]] = llvm.mlir.constant(15 : i32) : i32
// CHECK: %[[VAL_27:.*]] = llvm.insertvalue %[[VAL_26]], %[[VAL_25]][2 : i32, 0 : i32] : !llvm.array<3 x struct<(i32, ptr<i8>, ptr<i8>, i32, i32)>>
// CHECK: %[[VAL_28:.*]] = llvm.mlir.addressof @__npcomp_internal_constant_inputs1results2 : !llvm.ptr<array<15 x i8>>
// CHECK: %[[VAL_29:.*]] = llvm.getelementptr %[[VAL_28]]{{\[}}%[[VAL_1]], %[[VAL_1]]] : (!llvm.ptr<array<15 x i8>>, i32, i32) -> !llvm.ptr<i8>
// CHECK: %[[VAL_30:.*]] = llvm.insertvalue %[[VAL_29]], %[[VAL_27]][2 : i32, 1 : i32] : !llvm.array<3 x struct<(i32, ptr<i8>, ptr<i8>, i32, i32)>>
// CHECK: %[[VAL_31:.*]] = llvm.mlir.addressof @__refbackrt_wrapper_inputs1results2 : !llvm.ptr<func<void (ptr<ptr<i8>>, ptr<ptr<i8>>)>>
// CHECK: %[[VAL_32:.*]] = llvm.bitcast %[[VAL_31]] : !llvm.ptr<func<void (ptr<ptr<i8>>, ptr<ptr<i8>>)>> to !llvm.ptr<i8>
// CHECK: %[[VAL_33:.*]] = llvm.insertvalue %[[VAL_32]], %[[VAL_30]][2 : i32, 2 : i32] : !llvm.array<3 x struct<(i32, ptr<i8>, ptr<i8>, i32, i32)>>
// CHECK: %[[VAL_34:.*]] = llvm.mlir.constant(1 : i32) : i32
// CHECK: %[[VAL_35:.*]] = llvm.insertvalue %[[VAL_34]], %[[VAL_33]][2 : i32, 3 : i32] : !llvm.array<3 x struct<(i32, ptr<i8>, ptr<i8>, i32, i32)>>
// CHECK: %[[VAL_36:.*]] = llvm.mlir.constant(2 : i32) : i32
// CHECK: %[[VAL_37:.*]] = llvm.insertvalue %[[VAL_36]], %[[VAL_35]][2 : i32, 4 : i32] : !llvm.array<3 x struct<(i32, ptr<i8>, ptr<i8>, i32, i32)>>
// CHECK: llvm.return %[[VAL_37]] : !llvm.array<3 x struct<(i32, ptr<i8>, ptr<i8>, i32, i32)>>
// CHECK: }
// CHECK-LABEL: llvm.mlir.global external constant @_mlir___npcomp_module_descriptor() : !llvm.struct<(i32, ptr<struct<(i32, ptr<i8>, ptr<i8>, i32, i32)>>)> {
// CHECK: %[[VAL_0:.*]] = llvm.mlir.undef : !llvm.struct<(i32, ptr<struct<(i32, ptr<i8>, ptr<i8>, i32, i32)>>)>
// CHECK: %[[VAL_1:.*]] = llvm.mlir.constant(3 : i32) : i32
// CHECK: %[[VAL_2:.*]] = llvm.insertvalue %[[VAL_1]], %[[VAL_0]][0 : i32] : !llvm.struct<(i32, ptr<struct<(i32, ptr<i8>, ptr<i8>, i32, i32)>>)>
// CHECK: %[[VAL_3:.*]] = llvm.mlir.addressof @__npcomp_func_descriptors : !llvm.ptr<array<3 x struct<(i32, ptr<i8>, ptr<i8>, i32, i32)>>>
// CHECK: %[[VAL_4:.*]] = llvm.bitcast %[[VAL_3]] : !llvm.ptr<array<3 x struct<(i32, ptr<i8>, ptr<i8>, i32, i32)>>> to !llvm.ptr<struct<(i32, ptr<i8>, ptr<i8>, i32, i32)>>
// CHECK: %[[VAL_5:.*]] = llvm.insertvalue %[[VAL_4]], %[[VAL_2]][1 : i32] : !llvm.struct<(i32, ptr<struct<(i32, ptr<i8>, ptr<i8>, i32, i32)>>)>
// CHECK: llvm.return %[[VAL_5]] : !llvm.struct<(i32, ptr<struct<(i32, ptr<i8>, ptr<i8>, i32, i32)>>)>
// CHECK: }
// Runtime metadata (input/output counts) for the three functions below.
// The CHECK lines above verify its lowering into the
// @__npcomp_func_descriptors array and the @_mlir___npcomp_module_descriptor
// global (one struct<(i32, ptr<i8>, ptr<i8>, i32, i32)> per function).
refbackrt.module_metadata {
refbackrt.func_metadata {funcName = @inputs1results0, numInputs = 1 : i32, numOutputs = 0 : i32}
refbackrt.func_metadata {funcName = @inputs1results1, numInputs = 1 : i32, numOutputs = 1 : i32}
refbackrt.func_metadata {funcName = @inputs1results2, numInputs = 1 : i32, numOutputs = 2 : i32}
}
// One unranked-tensor input, no results; exists only to back the
// numInputs = 1 / numOutputs = 0 metadata entry above.
func @inputs1results0(%arg0: memref<*xf32>) {
return
}
// One input, one result (identity); backs the numInputs = 1 /
// numOutputs = 1 metadata entry above.
func @inputs1results1(%arg0: memref<*xf32>) -> memref<*xf32> {
return %arg0 : memref<*xf32>
}
// One input returned twice; backs the numInputs = 1 / numOutputs = 2
// metadata entry above.
func @inputs1results2(%arg0: memref<*xf32>) -> (memref<*xf32>, memref<*xf32>) {
return %arg0, %arg0 : memref<*xf32>, memref<*xf32>
}
// -----
// Test emission of compiler runtime functions.
// CHECK: llvm.mlir.global internal constant @[[STRSYM:.*]]("msg\00")
// CHECK: llvm.func @__npcomp_compiler_rt_abort_if(i1, !llvm.ptr<i8>)
// CHECK-LABEL: llvm.func @calls_abort_if(
// CHECK-SAME: %[[VAL_0:.*]]: i1) {
// CHECK: %[[VAL_0:.*]] = llvm.mlir.addressof @[[STRSYM]] : !llvm.ptr<array<4 x i8>>
// CHECK: %[[VAL_1:.*]] = llvm.mlir.constant(0 : i32) : i32
// CHECK: %[[VAL_2:.*]] = llvm.getelementptr %[[VAL_0]]{{\[}}%[[VAL_1]], %[[VAL_1]]] : (!llvm.ptr<array<4 x i8>>, i32, i32) -> !llvm.ptr<i8>
// CHECK: llvm.call @__npcomp_compiler_rt_abort_if(%[[VAL_3:.*]], %[[VAL_2]]) : (i1, !llvm.ptr<i8>) -> ()
// CHECK: llvm.return
// Exercises lowering of refbackrt.abort_if: per the CHECK lines above, the
// "msg" string becomes an internal LLVM global and a pointer to it is passed
// to the runtime function @__npcomp_compiler_rt_abort_if.
func @calls_abort_if(%arg0: i1) {
refbackrt.abort_if %arg0, "msg"
return
}

View File

@ -3,8 +3,14 @@
// Test module metadata. // Test module metadata.
// CHECK: refbackrt.module_metadata // CHECK: refbackrt.module_metadata
// CHECK-NEXT: refbackrt.func_metadata {funcName = @f_2inputs_0outputs, numInputs = 2 : i32, numOutputs = 0 : i32} // CHECK-NEXT: refbackrt.func_metadata
// CHECK-NEXT: refbackrt.func_metadata {funcName = @f_1input_2outputs, numInputs = 1 : i32, numOutputs = 2 : i32} // CHECK-SAME: funcName = @f_2inputs_0outputs
// CHECK-SAME: numInputs = 2
// CHECK-SAME: numOutputs = 0
// CHECK-NEXT: refbackrt.func_metadata
// CHECK-SAME: funcName = @f_1input_2outputs
// CHECK-SAME: numInputs = 1
// CHECK-SAME: numOutputs = 2
// This function only exists to test its metadata above. // This function only exists to test its metadata above.
func @f_2inputs_0outputs(%arg0: memref<?xf32>, %arg1: memref<?xf32>) { func @f_2inputs_0outputs(%arg0: memref<?xf32>, %arg1: memref<?xf32>) {

View File

@ -9,4 +9,3 @@ func @basic(%arg0: tensor<?xf32>) -> tensor<?xf32> {
%0 = tcf.add %arg0, %arg0 : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32> %0 = tcf.add %arg0, %arg0 : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
return %0 : tensor<?xf32> return %0 : tensor<?xf32>
} }

View File

@ -0,0 +1,26 @@
// RUN: not npcomp-run-mlir %s \
// RUN: -invoke invalid_input_shape \
// RUN: -arg-value="dense<1.0> : tensor<2x2x2x2xf32>" \
// RUN: -shared-libs=%npcomp_runtime_shlib 2>&1 \
// RUN: | FileCheck %s -check-prefix=ARG0-INVALID
// RUN: not npcomp-run-mlir %s \
// RUN: -invoke invalid_input_shape_arg1 \
// RUN: -arg-value="dense<1.0> : tensor<1x2x5xf32>" \
// RUN: -arg-value="dense<1.0> : tensor<1x2x10xf32>" \
// RUN: -shared-libs=%npcomp_runtime_shlib 2>&1 \
// RUN: | FileCheck %s -check-prefix=ARG1-INVALID
// ARG0-INVALID: invoking 'invalid_input_shape': input shape mismatch (%arg0).
// ARG0-INVALID-SAME: actual (provided by user): (2x2x2x2)
// ARG0-INVALID-SAME: expected (from compiler): (1x2x3x4)
// Compiler-side expected shape is static (1x2x3x4); the first RUN line above
// invokes it with a 2x2x2x2 tensor, so the runtime must report the %arg0
// shape mismatch checked by ARG0-INVALID.
func @invalid_input_shape(%arg0: tensor<1x2x3x4xf32>) -> tensor<1x2x3x4xf32> {
return %arg0: tensor<1x2x3x4xf32>
}
// ARG1-INVALID: invoking 'invalid_input_shape_arg1': input shape mismatch (%arg1)
// ARG1-INVALID-SAME: actual (provided by user): (1x2x10)
// ARG1-INVALID-SAME: expected (from compiler): (1x4x?)
// %arg1 expects 1x4x? (dynamic trailing dim) but the second RUN line passes
// 1x2x10; verifies the mismatch message names %arg1, not %arg0, and that
// dynamic dims print as '?' (ARG1-INVALID checks above).
func @invalid_input_shape_arg1(%arg0: tensor<1x2x?xf32>, %arg1: tensor<1x4x?xf32>) {
return
}

View File

@ -0,0 +1,13 @@
// RUN: not npcomp-run-mlir %s \
// RUN: -invoke expects_one_tensor \
// RUN: -arg-value="1.0 : f32" \
// RUN: -shared-libs=%npcomp_runtime_shlib 2>&1 \
// RUN: | FileCheck %s
// CHECK: invoking 'expects_one_tensor': input argument type mismatch.
// CHECK-SAME: actual (provided by user): Float
// CHECK-SAME: expected (from compiler): kTensor
// Passing the scalar "1.0 : f32" (RUN line above) where a tensor argument is
// expected must produce the input-argument-type mismatch error
// (actual: Float, expected: kTensor) checked above.
func @expects_one_tensor(%arg0: tensor<?xf32>) -> tensor<?xf32> {
%0 = tcf.add %arg0, %arg0 : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
return %0 : tensor<?xf32>
}

View File

@ -2,10 +2,21 @@
// RUN: -invoke scalar \ // RUN: -invoke scalar \
// RUN: -arg-value="dense<1.0> : tensor<f32>" \ // RUN: -arg-value="dense<1.0> : tensor<f32>" \
// RUN: -shared-libs=%npcomp_runtime_shlib 2>&1 \ // RUN: -shared-libs=%npcomp_runtime_shlib 2>&1 \
// RUN: | FileCheck %s // RUN: | FileCheck %s --check-prefix=SCALAR
// CHECK: output #0: dense<2.000000e+00> : tensor<f32> // RUN: npcomp-run-mlir %s \
// RUN: -invoke scalar_arg \
// RUN: -arg-value="2.5 : f32" \
// RUN: -shared-libs=%npcomp_runtime_shlib 2>&1 \
// RUN: | FileCheck %s --check-prefix=SCALAR_ARG
// SCALAR: output #0: dense<2.000000e+00> : tensor<f32>
func @scalar(%arg0: tensor<f32>) -> tensor<f32> { func @scalar(%arg0: tensor<f32>) -> tensor<f32> {
%0 = tcf.add %arg0, %arg0 : (tensor<f32>, tensor<f32>) -> tensor<f32> %0 = tcf.add %arg0, %arg0 : (tensor<f32>, tensor<f32>) -> tensor<f32>
return %0 : tensor<f32> return %0 : tensor<f32>
} }
// SCALAR_ARG: output #0: 2.500000e+00 : f32
// SCALAR_ARG run: an f32 scalar argument crosses the ABI boundary unchanged
// and is printed back as "2.500000e+00 : f32" (checked above).
func @scalar_arg(%arg0: f32) -> f32 {
return %arg0 : f32
}

View File

@ -58,6 +58,14 @@ convertAttrToTensor(Attribute attr) {
return make_string_error("unhandled argument"); return make_string_error("unhandled argument");
} }
// Converts a scalar float attribute (e.g. parsed from "2.5 : f32") into a
// C++ float for passing across the refbackrt ABI boundary.
//
// Returns an error (rather than crashing) if `attr` is not float-typed or is
// not actually a FloatAttr.
static Expected<float> convertAttrToFloat(Attribute attr) {
  auto type = attr.getType().dyn_cast<FloatType>();
  if (!type)
    return make_string_error("converting an argument to float that is not a FloatType");
  auto floatAttr = attr.dyn_cast<FloatAttr>();
  // Bug fix: the dyn_cast result was previously dereferenced unchecked; a
  // float-typed attribute that is not a FloatAttr would have been a null
  // dereference.
  if (!floatAttr)
    return make_string_error("expected a FloatAttr for float-typed argument");
  // NOTE(review): APFloat::convertToFloat expects single-precision semantics;
  // presumably only f32 values reach here — confirm for f16/f64 inputs.
  return floatAttr.getValue().convertToFloat();
}
static Expected<SmallVector<refbackrt::RtValue, 6>> static Expected<SmallVector<refbackrt::RtValue, 6>>
createInputs(ArrayRef<StringRef> argValues) { createInputs(ArrayRef<StringRef> argValues) {
MLIRContext context; MLIRContext context;
@ -66,12 +74,22 @@ createInputs(ArrayRef<StringRef> argValues) {
auto attr = parseAttribute(argValue, &context); auto attr = parseAttribute(argValue, &context);
if (!attr) if (!attr)
return make_string_error(Twine("could not parse arg value: ") + argValue); return make_string_error(Twine("could not parse arg value: ") + argValue);
// TODO(brycearden): Handle multiple input types
auto expectedTensor = convertAttrToTensor(attr); auto attrType = attr.getType();
if (!expectedTensor)
return expectedTensor.takeError(); if (attrType.isa<RankedTensorType>()) {
ret.push_back(std::move(*expectedTensor)); auto expectedTensor = convertAttrToTensor(attr);
if (!expectedTensor)
return expectedTensor.takeError();
ret.push_back(std::move(*expectedTensor));
} else if (attrType.isa<FloatType>()) {
auto expectedFloat = convertAttrToFloat(attr);
if (!expectedFloat)
return expectedFloat.takeError();
ret.push_back(refbackrt::RtValue(*expectedFloat));
}
} }
return ret; return ret;
} }
@ -92,34 +110,40 @@ static RankedTensorType getCorrespondingMLIRTensorType(refbackrt::Tensor &tensor
return RankedTensorType::get(extents, elementType); return RankedTensorType::get(extents, elementType);
} }
static Attribute convertToMLIRAttribute(refbackrt::Tensor &tensor, static Attribute convertToMLIRAttribute(const refbackrt::RtValue &value,
Builder &builder) { Builder &builder) {
RankedTensorType type = getCorrespondingMLIRTensorType(tensor, builder); if (value.isTensor()) {
switch (tensor.getElementType()) { auto& tensor = *(value.toTensor());
case refbackrt::ElementType::F32: { RankedTensorType type = getCorrespondingMLIRTensorType(tensor, builder);
SmallVector<float, 100> values; switch (tensor.getElementType()) {
auto *basePtr = tensor.getData<float>(); case refbackrt::ElementType::F32: {
for (int i = 0, e = type.getNumElements(); i < e; i++) SmallVector<float, 100> values;
values.push_back(basePtr[i]); auto *basePtr = tensor.getData<float>();
return DenseFPElementsAttr::get(type, values); for (int i = 0, e = type.getNumElements(); i < e; i++)
} values.push_back(basePtr[i]);
return DenseFPElementsAttr::get(type, values);
}
}
} else if (value.isFloat()) {
return builder.getF32FloatAttr(value.toFloat());
} else {
assert(false && "could not convert value to mlir attribute");
} }
llvm_unreachable("unsupported dtype"); llvm_unreachable("unsupported dtype");
} }
static void printOutput(refbackrt::Tensor &tensor, llvm::raw_ostream &os) { static void printOutput(const refbackrt::RtValue &value, llvm::raw_ostream &os) {
MLIRContext context; MLIRContext context;
Builder builder(&context); Builder builder(&context);
auto attr = convertToMLIRAttribute(tensor, builder); auto attr = convertToMLIRAttribute(value, builder);
attr.print(os); attr.print(os);
} }
static void printOutputs(ArrayRef<refbackrt::RtValue> outputs, static void printOutputs(ArrayRef<refbackrt::RtValue> outputs,
llvm::raw_ostream &os) { llvm::raw_ostream &os) {
for (auto output : llvm::enumerate(outputs)) { for (auto output : llvm::enumerate(outputs)) {
assert(output.value().isTensor() && "only tensor outputs are supported.");
os << "output #" << output.index() << ": "; os << "output #" << output.index() << ": ";
printOutput(*output.value().toTensor().get(), os); printOutput(output.value(), os);
os << "\n"; os << "\n";
} }
} }
@ -150,9 +174,11 @@ Error compileAndRun(std::string mlirFile, mlir::MLIRContext &context,
auto expectedInputs = createInputs(argValues); auto expectedInputs = createInputs(argValues);
if (!expectedInputs) if (!expectedInputs)
return expectedInputs.takeError(); return expectedInputs.takeError();
auto expectedOutputs = jitModule->invoke(invokeFunction, *expectedInputs); auto expectedOutputs = jitModule->invoke(invokeFunction, *expectedInputs);
if (!expectedOutputs) if (!expectedOutputs)
return expectedOutputs.takeError(); return expectedOutputs.takeError();
auto outputs = std::move(*expectedOutputs); auto outputs = std::move(*expectedOutputs);
printOutputs(outputs, llvm::outs()); printOutputs(outputs, llvm::outs());
llvm::outs() << "SUCCESS\n"; llvm::outs() << "SUCCESS\n";