[refbackrt] Scalar arg support

* Adds f32 scalar argument support across the ABI boundary.
* Adds support for passing input type / shape information
  across the ABI boundary
* Adds support for parsing / creating input FloatAttrs in
  `npcomp-run-mlir`
pull/197/head
Bryce Arden 2021-03-10 11:53:03 -06:00 committed by Sean Silva
parent 703428eff4
commit 4591884d06
17 changed files with 1085 additions and 269 deletions

View File

@ -68,12 +68,47 @@ def Refbackrt_FuncMetadataOp
let description = [{
Runtime metadata for a single func.
TODO: Augment this with information for type/shape checking of arguments.
Contains type / shape information for arguments as described below:
* ArgType(s):
Integer value from `CompilerDataStructures.h` for each argument
indicating what type it is (e.g. Float, Int, Tensor, Dict, etc.)
* ElementType(s):
Certain input ArgType's also have an element type (e.g. Tensor<float>,
List<int>, etc.)
TODO(brycearden): Support nested types (e.g. List<Tensor<float>>)
* Rank(s):
Integer value indicating the rank for each argument.
* Shape(s):
Flattened hyper-rectangular representation of the shapes for each argument.
Since each shape's size varies based on the Rank, we pad out the shapes
to size kMaxRank to make ABI lowering easier. See `LowerToRefbackrtABI.cpp`
for details.
Shapes Example:
constexpr int kMaxRank = 6;
// func @f(%arg0: f32, %arg1: tensor<5xf32>) would result in...
inputShapes = dense<...> : tensor<12xi32>
// 2 shapes with 6 elements each so that the LowerToLLVM pass
// can update the struct(s) by just grabbing a pointer at
// %shape_ptr = %base + (kMaxRank * argIndex)
// where only the first `rank` values in each shape are valid.
}];
let arguments = (ins
FlatSymbolRefAttr:$funcName,
I32Attr:$numInputs,
I32Attr:$numOutputs
I32Attr:$numOutputs,
OptionalAttr<I32ElementsAttr>:$inputArgTypes,
OptionalAttr<I32ElementsAttr>:$inputElementTypes,
OptionalAttr<I32ElementsAttr>:$inputRanks,
OptionalAttr<I32ElementsAttr>:$inputShapes,
// I32ElementsAttr:$inputIsStatic,
OptionalAttr<I32ElementsAttr>:$outputArgTypes,
OptionalAttr<I32ElementsAttr>:$outputElementTypes,
OptionalAttr<I32ElementsAttr>:$outputRanks,
OptionalAttr<I32ElementsAttr>:$outputShapes
//I32ElementsAttr:$outputIsStatic
);
let results = (outs);
let assemblyFormat = "attr-dict";

View File

@ -31,6 +31,8 @@ public:
return std::memcmp(ptr, other.ptr, length) == 0;
}
const char* str() { return ptr; }
private:
const char *ptr;
std::size_t length;

View File

@ -22,6 +22,7 @@
#define NPCOMP_RUNTIME_USERAPI_H
#include "npcomp/RefBackend/Runtime/Support.h"
#include <array>
#include <atomic>
#include <cstdlib>
@ -105,9 +106,11 @@ private:
// The available element (dtype) encodings for tensors crossing the ABI
// boundary. The numeric values must stay aligned with the compiler-side
// encoding (see `CompilerDataStructures.h` — TODO confirm exact enum name).
enum class ElementType : std::int32_t {
  NONE,
  F32,
};
// Returns the size in bytes of one element of `type`.
std::int32_t getElementTypeByteSize(ElementType type);
// Returns a human-readable name for `type`, for use in diagnostics.
StringRef getElementTypeAsStringRef(ElementType type);
// Representation of a tensor.
class Tensor : public RefTarget {
@ -124,6 +127,12 @@ public:
static Tensor *createRaw(ArrayRef<std::int32_t> extents,
ElementType elementType, void *data);
static Ref<Tensor> create(ArrayRef<std::int64_t> extents,
ElementType elementType, void *data);
// Same as `create`, but returns a raw pointer.
static Tensor *createRaw(ArrayRef<std::int64_t> extents,
ElementType elementType, void *data);
ElementType getElementType() const { return elementType; }
std::int32_t getRank() const { return rank; }
void *getData() const { return data; }
@ -169,6 +178,7 @@ private:
_(None) \
_(Bool) \
_(Int) \
_(Float) \
_(Double)
#define NPCOMP_FORALL_REF_TAGS(_) _(Tensor)
@ -193,15 +203,23 @@ struct RtValue final {
RtValue(std::int64_t i) : tag(Tag::Int) { payload.asInt = i; }
RtValue(std::int32_t i) : RtValue(static_cast<int64_t>(i)) {}
bool isInt() const { return Tag::Int == tag; }
bool toInt() const {
int64_t toInt() const {
assert(isInt());
return payload.asInt;
}
// Float
RtValue(float f) : tag(Tag::Float) { payload.asFloat = f; }
bool isFloat() const { return Tag::Float == tag; }
float toFloat() const {
assert(isFloat());
return payload.asFloat;
}
// Double
RtValue(double d) : tag(Tag::Double) { payload.asDouble = d; }
bool isDouble() const { return Tag::Double == tag; }
bool toDouble() const {
double toDouble() const {
assert(isDouble());
return payload.asDouble;
}
@ -227,6 +245,11 @@ struct RtValue final {
return false;
}
// Scalar
bool isScalar() const {
return isBool() || isInt() || isFloat() || isDouble();
}
// RtValue (downcast)
const RtValue &toRtValue() const { return *this; }
RtValue &toRtValue() { return *this; }
@ -298,6 +321,7 @@ private:
union Payload {
bool asBool;
int64_t asInt;
float asFloat;
double asDouble;
void *asVoidPtr;
};
@ -313,19 +337,72 @@ private:
// This is the main entry point that users interact with.
//===----------------------------------------------------------------------===//
// Kind of a function argument/result as recorded in the compiler-emitted
// metadata. The numeric values must stay aligned with the compiler-side
// encoding in `LowerToRefbackrtABI.cpp` (kNone = 0, kTensor = 1, kF32 = 2,
// kF64 = 3).
enum class ArgType : std::uint32_t {
  kNone = 0,
  kTensor,
  kF32,
  kF64,
};
// Returns a human-readable name for `type`, for use in diagnostics.
StringRef getArgTypeAsStringRef(ArgType type);
// Maximum rank supported across the ABI boundary.
// NOTE: must stay consistent with the `kMaxRank` constants in the compiler's
// ABI lowering passes.
constexpr static int kMaxRank = 6;
// Type/shape metadata for a single function input, as described by the
// compiler-emitted descriptors. Used to validate user-provided arguments
// before invocation.
struct InputArgInfo {
  // What type of argument this is
  ArgType argType;
  // Certain arg types also have an element type
  ElementType elementType;
  // Rank of the argument; only meaningful for tensor-like argTypes.
  std::int32_t rank;
  // Extents padded out to kMaxRank; only the first `rank` entries are valid.
  std::array<std::int32_t, kMaxRank> extents;
};
// Type/shape metadata for a single function output, as described by the
// compiler-emitted descriptors. Used to pre-construct output RtValues of
// the correct kind before invocation.
struct OutputArgInfo {
  // What type of argument this is
  ArgType argType;
  // Certain arg types also have an element type
  ElementType elementType;
  // Rank of the result; only meaningful for tensor-like argTypes.
  std::int32_t rank;
  // Extents padded out to kMaxRank; only the first `rank` entries are valid.
  std::array<std::int32_t, kMaxRank> extents;
  // TODO(brycearden): Add checks for whether output buffers alias to input
  // buffers and populate field(s) here indicating that case
};
// Maximum input or output arity.
constexpr static int kMaxArity = 20;
// Metadata for a particular function: arity plus per-argument type/shape
// info. Only the first numInputs/numOutputs entries of the arrays are
// populated.
struct FunctionMetadata {
  std::int32_t numInputs;
  std::int32_t numOutputs;
  std::array<InputArgInfo, kMaxArity> inputArgInfos;
  std::array<OutputArgInfo, kMaxArity> outputArgInfos;
};
// Opaque forward declaration of module descriptor type. This is the type
// created by the compiler in the module binary.
//
// NOTE: the duplicate definition of `kMaxArity` that previously followed
// this declaration has been removed — it is already defined above, and a
// second definition at the same scope is a redefinition error.
struct ModuleDescriptor;
// Verifies that the RtValue argument types provided by the user match the
// types we expect from the descriptors emitted by the compiler.
//
// Returns failure if the input type(s) are not valid.
LogicalResult checkRtValueArgTypes(const RtValue &value,
const InputArgInfo &info);
// Verifies that the RtValue argument shapes provided by the user match the
// shapes we expect from the descriptors emitted by the compiler.
//
// Returns failure if the input shape(s) are not valid.
LogicalResult checkRtValueShapes(const RtValue &value,
const InputArgInfo &info);
// Creates an RtValue of the right type from the output metadata
// provided by the compiled module
RtValue createRtValueFromOutputArgInfo(const OutputArgInfo &info);
// Low-level invocation API. The number of inputs and outputs should be correct
// and match the results of getMetadata.

View File

@ -54,6 +54,16 @@ static LogicalResult verify(FuncMetadataOp op) {
return op.emitError() << "must agree on number of inputs";
if (op.numOutputs() != func.getNumResults())
return op.emitError() << "must agree on number of outputs";
if (op.numInputs() > 0) {
if (op.numInputs() != op.inputArgTypes()->size()) {
return op.emitError() << "number of inputTypes must match number of inputs";
}
}
if (op.numOutputs() > 0) {
if (op.numOutputs() != op.outputArgTypes()->size())
return op.emitError() << "number of outputTypes must match number of outputs";
}
return success();
}

View File

@ -12,6 +12,8 @@
#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h"
#include "npcomp/RefBackend/RefBackend.h"
#include <sstream>
using namespace refback;
using namespace mlir;
using llvm::Error;
@ -73,6 +75,22 @@ static refbackrt::MutableArrayRef<T> toRefbackrt(llvm::MutableArrayRef<T> a) {
return refbackrt::MutableArrayRef<T>(a.data(), a.size());
}
// Renders `extents` as a human-readable shape string, e.g. "(2x?x4)".
// Dynamic dimensions (encoded as negative extents) are printed as "?".
static std::string stringifyShape(refbackrt::ArrayRef<std::int32_t> extents) {
  // A string literal has type `const char[]`; binding it to `char *` is
  // ill-formed since C++11, so the pointee must be const.
  static constexpr const char *kDynamicDimAsString = "?";
  std::stringstream ss;
  ss << "(";
  // Use std::size_t for the index to avoid a signed/unsigned comparison
  // against extents.size().
  for (std::size_t i = 0; i < extents.size(); i++) {
    if (extents[i] < 0)
      ss << kDynamicDimAsString;
    else
      ss << extents[i];
    if (i != extents.size() - 1)
      ss << "x";
  }
  ss << ")";
  return ss.str();
}
llvm::Expected<llvm::SmallVector<refbackrt::RtValue, 6>>
JITModule::invoke(llvm::StringRef functionName,
llvm::ArrayRef<refbackrt::RtValue> inputs) {
@ -80,12 +98,47 @@ JITModule::invoke(llvm::StringRef functionName,
if (refbackrt::failed(refbackrt::getMetadata(
descriptor, toRefbackrt(functionName), metadata)))
return make_string_error("unknown function: " + Twine(functionName));
SmallVector<refbackrt::RtValue, 6> outputs(
metadata.numOutputs);
SmallVector<refbackrt::RtValue, 6> outputs(metadata.numOutputs);
if (metadata.numInputs != static_cast<std::int32_t>(inputs.size()))
return make_string_error("invoking '" + Twine(functionName) +
"': expected " + Twine(metadata.numInputs) +
" inputs");
// Verify user input types and shapes match what the compiler expects
for (int i = 0; i < metadata.numInputs; i++) {
auto &input = inputs[i];
auto &inputArgInfo = metadata.inputArgInfos[i];
if (refbackrt::failed(checkRtValueArgTypes(input, inputArgInfo)))
return make_string_error(
"invoking '" + Twine(functionName) +
"': input argument type mismatch. actual (provided by user): " +
Twine(inputs[i].tagKind().str()) + ", expected (from compiler): " +
Twine(getArgTypeAsStringRef(inputArgInfo.argType).str()));
if (refbackrt::failed(checkRtValueShapes(input, inputArgInfo)))
return make_string_error(
"invoking '" + Twine(functionName) + "': input shape mismatch (%arg" +
Twine(i) + "). " + "actual (provided by user): " +
stringifyShape(input.toTensor()->getExtents()) +
", expected (from compiler): " +
stringifyShape(refbackrt::ArrayRef<int32_t>(
inputArgInfo.extents.data(), inputArgInfo.rank)));
}
// Create the correct output RtValue based on FuncMetadata,
// which contains the arg types (scalar, Tensor, etc.), element types (only
// applicable if not scalar) and shapes (also only applicable if not scalar)
//
// Currently we have to give each RtValue an output type so that we know
// how to pack / unpack the outputs properly across the ABI boundary in
// refbackrt::invoke. As a result, we can't just rely on the default
// construction of each output argument type (otherwise RtValue will have
// Tag::kNone) currently without passing the ArgInfo structs down to the
// Runtime level, so we deal with the output type creation here.
for (int i = 0; i < metadata.numOutputs; i++) {
outputs[i] = std::move(
refbackrt::createRtValueFromOutputArgInfo(metadata.outputArgInfos[i]));
}
refbackrt::invoke(
descriptor, toRefbackrt(functionName), toRefbackrt(inputs),
toRefbackrt(llvm::makeMutableArrayRef(outputs.data(), outputs.size())));

View File

@ -34,25 +34,70 @@ using mlir::LLVM::LLVMVoidType;
// These correspond to the types in CompilerDataStructures.h
//===----------------------------------------------------------------------===//
// MaxRank that the refbackrt ABI lowering is capable of handling.
// Shapes are padded out to exactly this many extents per argument.
// NOTE: This parameter must stay consistent with
// `lib/RefBackend/LowerToRefbackrtABI.cpp`
static constexpr int kMaxRank = 6;
// Returns the LLVM-dialect `i8*` type, used for strings and type-erased
// function pointers in the descriptor structs.
static LLVMPointerType getInt8PointerType(MLIRContext *context) {
  return LLVMPointerType::get(IntegerType::get(context, 8));
}
// Returns the LLVM-dialect `i32*` type, used for pointers into the
// flattened shape/extent arrays.
static LLVMPointerType getInt32PointerType(MLIRContext *context) {
  return LLVMPointerType::get(IntegerType::get(context, 32));
}
// LLVM struct type mirroring refbackrt's per-input argument descriptor
// (must stay in sync with `CompilerDataStructures.h`).
// Fields: ArgType (i32), ElementType (i32), Rank (i32), Extents (i32*).
// A trailing i32 IsStatic field is anticipated but not yet enabled.
static LLVMStructType getInputDescriptorTy(MLIRContext *context) {
  auto i32Ty = IntegerType::get(context, 32);
  auto i32PtrTy = LLVMPointerType::get(i32Ty);
  return LLVMStructType::getLiteral(context,
                                    {
                                        // ArgType
                                        i32Ty,
                                        // ElementType
                                        i32Ty,
                                        // Rank
                                        i32Ty,
                                        // Extents
                                        i32PtrTy,
                                    });
}
// LLVM struct type mirroring refbackrt's per-output result descriptor
// (must stay in sync with `CompilerDataStructures.h`).
// Currently identical in layout to the input descriptor:
// ArgType (i32), ElementType (i32), Rank (i32), Extents (i32*).
// A trailing i32 IsStatic field is anticipated but not yet enabled.
static LLVMStructType getOutputDescriptorTy(MLIRContext *context) {
  auto i32Ty = IntegerType::get(context, 32);
  auto i32PtrTy = LLVMPointerType::get(i32Ty);
  return LLVMStructType::getLiteral(context,
                                    {
                                        // ArgType
                                        i32Ty,
                                        // ElementType
                                        i32Ty,
                                        // Rank
                                        i32Ty,
                                        // Extents
                                        i32PtrTy,
                                    });
}
// Get the LLVM type for refbackrt::FuncDescriptor.
// Must stay in sync with the runtime's FuncDescriptor layout.
//
// NOTE: the previous five-field body (without the descriptor pointers) was
// left above the new return statement, making the seven-field version
// unreachable; the stale body has been removed.
static LLVMStructType getFuncDescriptorTy(MLIRContext *context) {
  return LLVMStructType::getLiteral(
      context, {
                   // Name length.
                   IntegerType::get(context, 32),
                   // Name chars.
                   getInt8PointerType(context),
                   // Type-erased function pointer.
                   getInt8PointerType(context),
                   // Number of inputs.
                   IntegerType::get(context, 32),
                   // Number of outputs.
                   IntegerType::get(context, 32),
                   // Argument descriptors
                   LLVMPointerType::get(getInputDescriptorTy(context)),
                   // Result Descriptors
                   LLVMPointerType::get(getOutputDescriptorTy(context)),
               });
}
// Get the LLVM type for refbackrt::ModuleDescriptor.
@ -92,8 +137,8 @@ static LLVM::GlobalOp createGlobalString(ModuleOp module, StringAttr msg,
// TODO: Deduplicate strings.
std::string msgNulTerminated = msg.getValue().str();
msgNulTerminated.push_back('\0');
auto arrayTy = LLVMArrayType::get(
IntegerType::get(module.getContext(), 8), msgNulTerminated.size());
auto arrayTy = LLVMArrayType::get(IntegerType::get(module.getContext(), 8),
msgNulTerminated.size());
OpBuilder::InsertionGuard guard(builder);
builder.setInsertionPointToStart(module.getBody());
@ -129,9 +174,9 @@ public:
auto globalOp = createGlobalString(op->getParentOfType<ModuleOp>(),
op.msgAttr(), rewriter, op.getLoc());
auto msgArray = rewriter.create<LLVM::AddressOfOp>(op.getLoc(), globalOp);
auto c0 = rewriter.create<LLVM::ConstantOp>(
op.getLoc(), IntegerType::get(context, 32),
rewriter.getI32IntegerAttr(0));
auto c0 = rewriter.create<LLVM::ConstantOp>(op.getLoc(),
IntegerType::get(context, 32),
rewriter.getI32IntegerAttr(0));
auto msg =
rewriter.create<LLVM::GEPOp>(op.getLoc(), getInt8PointerType(context),
msgArray, ValueRange({c0, c0}));
@ -181,16 +226,181 @@ createFuncDescriptorArray(ArrayRef<refbackrt::FuncMetadataOp> funcMetadatas,
auto llvmI32Ty = IntegerType::get(builder.getContext(), 32);
DenseMap<StringRef, LLVM::GlobalOp> globalsByName;
DenseMap<StringRef, LLVM::GlobalOp> inputDescriptorsByName;
DenseMap<StringRef, LLVM::GlobalOp> outputDescriptorsByName;
DenseMap<StringRef, LLVM::GlobalOp> inputShapesByName;
DenseMap<StringRef, LLVM::GlobalOp> outputShapesByName;
for (auto funcMetadata : funcMetadatas) {
auto arrayTy =
LLVMArrayType::get(IntegerType::get(builder.getContext(), 8),
funcMetadata.funcName().size());
auto arrayTy = LLVMArrayType::get(IntegerType::get(builder.getContext(), 8),
funcMetadata.funcName().size());
std::string llvmSymbolName =
(Twine("__npcomp_internal_constant_") + funcMetadata.funcName()).str();
auto global = builder.create<LLVM::GlobalOp>(
loc, arrayTy, /*isConstant=*/true, LLVM::Linkage::Internal,
llvmSymbolName, builder.getStringAttr(funcMetadata.funcName()));
globalsByName[funcMetadata.funcName()] = global;
// Create constants for the input / output shapes
if (funcMetadata.inputShapes().hasValue()) {
auto i32ArrayInputSymbolName =
(Twine("__npcomp_internal_constant_input_shapes_") +
funcMetadata.funcName())
.str();
auto inputNumElements = funcMetadata.inputShapes()->getNumElements();
auto inputI32ArrayTy =
LLVMArrayType::get(builder.getIntegerType(32), inputNumElements);
auto inputShapesGlobal = builder.create<LLVM::GlobalOp>(
loc, inputI32ArrayTy, /*isConstant=*/true, LLVM::Linkage::Internal,
i32ArrayInputSymbolName,
/*value=*/funcMetadata.inputShapes().getValue());
inputShapesByName[funcMetadata.funcName()] = inputShapesGlobal;
}
if (funcMetadata.outputShapes().hasValue()) {
auto i32ArrayOutputSymbolName =
(Twine("__npcomp_internal_constant_output_shapes_") +
funcMetadata.funcName())
.str();
auto outputNumElements = funcMetadata.outputShapes()->getNumElements();
auto outputI32ArrayTy =
LLVMArrayType::get(builder.getIntegerType(32), outputNumElements);
auto outputShapesGlobal = builder.create<LLVM::GlobalOp>(
loc, outputI32ArrayTy, /*isConstant=*/true, LLVM::Linkage::Internal,
i32ArrayOutputSymbolName,
/*value=*/funcMetadata.outputShapes().getValue());
outputShapesByName[funcMetadata.funcName()] = outputShapesGlobal;
}
}
auto updateDescriptor = [&](Value &descriptor, Value value,
std::initializer_list<int32_t> position) {
descriptor = builder.create<LLVM::InsertValueOp>(
loc, descriptor, value,
/*position=*/builder.getI32ArrayAttr(position));
};
auto updateDescriptorWithI32Attr =
[&](Value &descriptor, Attribute attr,
std::initializer_list<int32_t> position) {
auto constant = builder.create<LLVM::ConstantOp>(loc, llvmI32Ty, attr);
updateDescriptor(descriptor, constant, position);
};
// Create global input descriptors
for (auto funcMetadata : funcMetadatas) {
std::string llvmInputSymbolName =
(Twine("__npcomp_input_descriptors_") + funcMetadata.funcName()).str();
auto inputDescriptorTy = getInputDescriptorTy(builder.getContext());
auto inputDescriptorArrayTy =
LLVMArrayType::get(inputDescriptorTy, funcMetadata.numInputs());
auto inputDescriptorArrayGlobal = builder.create<LLVM::GlobalOp>(
loc, inputDescriptorArrayTy, /*isConstant=*/true,
LLVM::Linkage::Internal, llvmInputSymbolName, /*value=*/Attribute());
OpBuilder::InsertionGuard guard(builder);
builder.createBlock(&inputDescriptorArrayGlobal.initializer());
auto c0 = builder.create<LLVM::ConstantOp>(loc, llvmI32Ty,
builder.getI32IntegerAttr(0));
Value inputDescriptorArray =
builder.create<LLVM::UndefOp>(loc, inputDescriptorArrayTy);
for (int i = 0; i < funcMetadata.numInputs(); i++) {
// Arg Type
if (!funcMetadata.inputArgTypes().hasValue())
funcMetadata.emitError()
<< "numInputs > 0 but there are no inputArgTypes?";
updateDescriptorWithI32Attr(inputDescriptorArray,
funcMetadata.inputArgTypes()->getValue(i),
{i, 0});
// Element Type
updateDescriptorWithI32Attr(inputDescriptorArray,
funcMetadata.inputElementTypes()->getValue(i),
{i, 1});
// Rank
// auto inputShapesType =
// funcMetadata.inputShapes()->getType().dyn_cast<ShapedType>();
auto rank = funcMetadata.inputRanks()->getValue(i);
updateDescriptorWithI32Attr(inputDescriptorArray, rank, {i, 2});
// Shape
// Each shape array is derived by an offset of kMaxRank * arg index
auto extentsArray = builder.create<LLVM::AddressOfOp>(
loc, inputShapesByName[funcMetadata.funcName()]);
auto cShapeOffset = builder.create<LLVM::ConstantOp>(
loc, IntegerType::get(builder.getContext(), 32),
builder.getI32IntegerAttr(i * kMaxRank));
auto extentsArrayPtr = builder.create<LLVM::GEPOp>(
loc, getInt32PointerType(builder.getContext()), extentsArray,
ValueRange({c0, cShapeOffset}));
updateDescriptor(inputDescriptorArray, extentsArrayPtr, {i, 3});
}
builder.create<LLVM::ReturnOp>(loc, inputDescriptorArray);
inputDescriptorsByName[funcMetadata.funcName()] =
std::move(inputDescriptorArrayGlobal);
}
// Create global output descriptors
for (auto funcMetadata : funcMetadatas) {
std::string llvmOutputSymbolName =
(Twine("__npcomp_output_descriptors_") + funcMetadata.funcName()).str();
auto outputDescriptorTy = getOutputDescriptorTy(builder.getContext());
auto outputDescriptorArrayTy =
LLVMArrayType::get(outputDescriptorTy, funcMetadata.numOutputs());
auto outputDescriptorArrayGlobal = builder.create<LLVM::GlobalOp>(
loc, outputDescriptorArrayTy, /*isConstant=*/true,
LLVM::Linkage::Internal, llvmOutputSymbolName, /*value=*/Attribute());
OpBuilder::InsertionGuard guard(builder);
builder.createBlock(&outputDescriptorArrayGlobal.initializer());
auto c0 = builder.create<LLVM::ConstantOp>(loc, llvmI32Ty,
builder.getI32IntegerAttr(0));
Value outputDescriptorArray =
builder.create<LLVM::UndefOp>(loc, outputDescriptorArrayTy);
for (int i = 0; i < funcMetadata.numOutputs(); i++) {
if (!funcMetadata.outputArgTypes().hasValue())
funcMetadata.emitError()
<< "numOutputs > 0 but there are no outputArgTypes?";
// Arg Type
updateDescriptorWithI32Attr(outputDescriptorArray,
funcMetadata.outputArgTypes()->getValue(i),
{i, 0});
// Element Type
updateDescriptorWithI32Attr(
outputDescriptorArray, funcMetadata.outputElementTypes()->getValue(i),
{i, 1});
// Rank
// auto outputShapesType =
// funcMetadata.outputShapes()->getType().dyn_cast<ShapedType>();
auto rank = funcMetadata.outputRanks()->getValue(i);
updateDescriptorWithI32Attr(outputDescriptorArray, rank, {i, 2});
// Shapes
// Offset by kMaxRank * arg index
auto extentsArray = builder.create<LLVM::AddressOfOp>(
loc, outputShapesByName[funcMetadata.funcName()]);
auto cShapeOffset = builder.create<LLVM::ConstantOp>(
loc, IntegerType::get(builder.getContext(), 32),
builder.getI32IntegerAttr(i * kMaxRank));
auto extentsArrayPtr = builder.create<LLVM::GEPOp>(
loc, getInt32PointerType(builder.getContext()), extentsArray,
ValueRange({c0, cShapeOffset}));
updateDescriptor(outputDescriptorArray, extentsArrayPtr, {i, 3});
}
builder.create<LLVM::ReturnOp>(loc, outputDescriptorArray);
outputDescriptorsByName[funcMetadata.funcName()] =
outputDescriptorArrayGlobal;
}
// This must match FuncDescriptor in the runtime.
@ -201,31 +411,23 @@ createFuncDescriptorArray(ArrayRef<refbackrt::FuncMetadataOp> funcMetadatas,
loc, funcDescriptorArrayTy, /*isConstant=*/true, LLVM::Linkage::Internal,
"__npcomp_func_descriptors",
/*value=*/Attribute());
OpBuilder::InsertionGuard guard(builder);
builder.createBlock(&funcDescriptorArrayGlobal.initializer());
auto c0 = builder.create<LLVM::ConstantOp>(loc, llvmI32Ty,
builder.getI32IntegerAttr(0));
// Build the initializer.
Value funcDescriptorArray =
builder.create<LLVM::UndefOp>(loc, funcDescriptorArrayTy);
auto updateDescriptor = [&](Value value,
std::initializer_list<int32_t> position) {
funcDescriptorArray = builder.create<LLVM::InsertValueOp>(
loc, funcDescriptorArray, value,
/*position=*/builder.getI32ArrayAttr(position));
};
auto updateDescriptorWithI32Attr =
[&](Attribute attr, std::initializer_list<int32_t> position) {
auto constant = builder.create<LLVM::ConstantOp>(loc, llvmI32Ty, attr);
updateDescriptor(constant, position);
};
auto c0 = builder.create<LLVM::ConstantOp>(loc, llvmI32Ty,
builder.getI32IntegerAttr(0));
for (auto funcMetadataAndIndex : llvm::enumerate(funcMetadatas)) {
auto funcMetadata = funcMetadataAndIndex.value();
int32_t index = funcMetadataAndIndex.index();
// Name length.
updateDescriptorWithI32Attr(
funcDescriptorArray,
builder.getI32IntegerAttr(funcMetadata.funcName().size()), {index, 0});
// Name chars.
@ -234,7 +436,7 @@ createFuncDescriptorArray(ArrayRef<refbackrt::FuncMetadataOp> funcMetadatas,
auto funcNamePtr = builder.create<LLVM::GEPOp>(
loc, getInt8PointerType(builder.getContext()), funcNameArray,
ValueRange({c0, c0}));
updateDescriptor(funcNamePtr, {index, 1});
updateDescriptor(funcDescriptorArray, funcNamePtr, {index, 1});
// Function pointer.
//
@ -247,13 +449,31 @@ createFuncDescriptorArray(ArrayRef<refbackrt::FuncMetadataOp> funcMetadatas,
loc, getInt8PointerType(builder.getContext()), funcMetadata.funcName());
auto typeErasedFuncAddress = builder.create<LLVM::BitcastOp>(
loc, getInt8PointerType(builder.getContext()), funcAddress);
updateDescriptor(typeErasedFuncAddress, {index, 2});
updateDescriptor(funcDescriptorArray, typeErasedFuncAddress, {index, 2});
// Number of inputs.
updateDescriptorWithI32Attr(funcMetadata.numInputsAttr(), {index, 3});
updateDescriptorWithI32Attr(funcDescriptorArray,
funcMetadata.numInputsAttr(), {index, 3});
// Number of outputs.
updateDescriptorWithI32Attr(funcMetadata.numOutputsAttr(), {index, 4});
updateDescriptorWithI32Attr(funcDescriptorArray,
funcMetadata.numOutputsAttr(), {index, 4});
// Input descriptors
auto inputDescriptorsArrayAddress = builder.create<LLVM::AddressOfOp>(
loc, inputDescriptorsByName[funcMetadata.funcName()]);
auto rawInputDescriptorsPtr = builder.create<LLVM::BitcastOp>(
loc, LLVMPointerType::get(getInputDescriptorTy(builder.getContext())),
inputDescriptorsArrayAddress);
updateDescriptor(funcDescriptorArray, rawInputDescriptorsPtr, {index, 5});
// Output descriptors
auto outputDescriptorsArrayAddress = builder.create<LLVM::AddressOfOp>(
loc, outputDescriptorsByName[funcMetadata.funcName()]);
auto rawOutputDescriptorsPtr = builder.create<LLVM::BitcastOp>(
loc, LLVMPointerType::get(getOutputDescriptorTy(builder.getContext())),
outputDescriptorsArrayAddress);
updateDescriptor(funcDescriptorArray, rawOutputDescriptorsPtr, {index, 6});
}
builder.create<LLVM::ReturnOp>(loc, funcDescriptorArray);
@ -379,6 +599,16 @@ static Type getUnrankedMemrefDescriptorType(MLIRContext *context) {
/*memorySpace=*/0));
}
// Returns the LLVM-dialect type corresponding to f64, for storing scalar
// double results across the ABI boundary.
static Type getDoubleType(MLIRContext *context) {
  return LLVMTypeConverter(context).convertType(FloatType::getF64(context));
}
// Returns the LLVM-dialect type corresponding to f32, for storing scalar
// float results across the ABI boundary.
static Type getFloatType(MLIRContext *context) {
  return LLVMTypeConverter(context).convertType(FloatType::getF32(context));
}
// Writes out the logical results of the wrapper function through the void**
// passed on the ABI boundary. Because LLVM (and hence llvm.func)
// only supports a single return type (or void/no results), the logic here needs
@ -393,12 +623,18 @@ static void storeWrapperResults(LLVM::CallOp callToWrapped, Value resultsPtrPtr,
return;
Value result = callToWrapped.getResult(0);
auto ty = result.getType();
// 1 logical result.
if (ty == getUnrankedMemrefDescriptorType(ty.getContext())) {
Value addr =
getTypedAddressFromVoidStarStar(resultsPtrPtr, 0, ty, builder, loc);
builder.create<LLVM::StoreOp>(loc, result, addr);
return;
} else if (ty == getFloatType(ty.getContext())) {
Value addr =
getTypedAddressFromVoidStarStar(resultsPtrPtr, 0, ty, builder, loc);
builder.create<LLVM::StoreOp>(loc, result, addr);
return;
}
assert(ty.isa<LLVMStructType>() && "must be a multi-result packed struct!");
auto structType = ty.cast<LLVMStructType>();

View File

@ -21,6 +21,16 @@
using namespace mlir;
using namespace mlir::NPCOMP;
// Since the collection of input/output shapes is ragged (each argument has
// its own rank) we specify a maximum rank and pad every shape out to
// kMaxRank at the ABI boundary. That way the whole collection can be
// represented with a traditional rectangular DenseElementsAttr.
//
// NOTE: When changing this parameter, also change the same `kMaxRank`
// parameter in `lib/RefBackend/LowerToLLVM.cpp` so that the LLVM lowering
// stays consistent.
static constexpr int kMaxRank = 6;
// Get the type used to represent MemRefType `type` on ABI boundaries.
// For convenience we do a cast to MemRefType internally.
static Type getABIMemrefType(Type type) {
@ -38,7 +48,99 @@ static bool expressibleWithRefbackrtABI(FunctionType type) {
// Currently, only memref types can be exposed at refbackrt ABI boundaries.
return llvm::all_of(
llvm::concat<const Type>(type.getInputs(), type.getResults()),
[](Type t) { return t.isa<MemRefType>(); });
[](Type t) {
return t.isa<UnrankedMemRefType, MemRefType, FloatType>();
});
}
// Returns the integer representation of the CompilerDataStructures::ABIArgType
// for `type`. Must stay aligned with the CompilerDataStructures::ABIArgType
// enum (kNone = 0, kTensor = 1, kF32 = 2, kF64 = 3).
//
// Unsupported types map to the sentinel ~0u (the previous empty IntegerType
// branch and its unused variable have been removed).
static uint32_t getIntReprForABIType(Type type) {
  if (type.isa<MemRefType>() || type.isa<UnrankedMemRefType>()) {
    // kTensor
    return 1;
  } else if (auto floatTy = type.dyn_cast<FloatType>()) {
    switch (floatTy.getWidth()) {
    case 32:
      return 2; // kF32
    case 64:
      return 3; // kF64
    default:
      assert(false && "Unsupported float bit width");
    }
  }
  // TODO(brycearden): Support IntegerType arguments across the ABI boundary.
  return static_cast<uint32_t>(-1);
}
// Returns the integer representation of the element type of `type`.
// Must stay aligned with CompilerDataStructures::ABIElementType enum
// (0 = None, 1 = F32). Non-shaped and non-float element types map to 0.
static uint32_t getIntReprForABIElementType(Type type) {
  auto shapedTy = type.dyn_cast<ShapedType>();
  if (!shapedTy)
    return 0;
  auto floatTy = shapedTy.getElementType().dyn_cast<FloatType>();
  if (!floatTy)
    return 0;
  switch (floatTy.getWidth()) {
  case 32:
    return 1;
  default:
    assert(false && "Unsupported tensor element type");
  }
  return 0;
}
// Returns the extents of `type` padded out to `maxRank` entries so that
// every argument occupies a fixed-size shape slot at the ABI boundary.
// * Dynamic dimensions are encoded as -1.
// * Padding entries past the real rank are filled with a 0xDEADBEEF
//   sentinel; only the first `rank` entries are meaningful to consumers.
// * Scalars and unranked shaped types get a vector of kMaxRank sentinels.
static SmallVector<int32_t, kMaxRank>
getExtentsForType(Type type, const int32_t maxRank = kMaxRank) {
  // Pad all shapes out to kMaxRank to make our lives easier at the ABI
  // boundary. (Previous comment said "4D", which contradicted kMaxRank = 6.)
  if (auto shapedType = type.dyn_cast<ShapedType>()) {
    if (!shapedType.hasRank()) {
      return {kMaxRank, kMaxRank, kMaxRank, kMaxRank, kMaxRank, kMaxRank};
    }
    auto shape = shapedType.getShape();
    auto shapeRank = shapedType.getRank();
    if (shapeRank <= maxRank) {
      SmallVector<int32_t, kMaxRank> extendedShape;
      // Push back all the values of the shape, encoding dynamic dims as -1.
      for (auto extentAndIndex : llvm::enumerate(shape)) {
        auto extent = extentAndIndex.value();
        auto index = extentAndIndex.index();
        if (shapedType.isDynamic(index)) {
          extendedShape.push_back(-1);
        } else {
          extendedShape.push_back(extent);
        }
      }
      // Pad whatever is left with a recognizable sentinel so every shape
      // slot has exactly maxRank entries.
      auto padRank = maxRank - shapeRank;
      for (int i = 0; i < padRank; i++)
        extendedShape.push_back(0xDEAD'BEEF);
      return extendedShape;
    }
    // Ranks beyond maxRank cannot be represented across the ABI boundary.
    // In release builds this (incorrectly-typed) input falls through to the
    // scalar representation below, matching the previous behavior; the
    // assert catches it in debug builds.
    assert(false && "unsupported rank");
  }
  // Represent scalars as all-kMaxRank sentinel extents.
  return {kMaxRank, kMaxRank, kMaxRank, kMaxRank, kMaxRank, kMaxRank};
}
// Returns the rank to advertise across the ABI boundary for `type`:
// the shaped rank, -1 for unranked shaped types, and 0 for scalars.
// Marked `static` for internal linkage, consistent with the other
// file-local helpers in this pass.
static int32_t getRankForType(Type type) {
  // Returns a rank of -1 if the tensor is unranked
  if (auto shapedType = type.dyn_cast<ShapedType>()) {
    return shapedType.hasRank() ? shapedType.getRank() : -1;
  }
  return 0;
}
// Returns 1 if `type` has a fully static shape, 0 otherwise. Scalars and
// other non-shaped types are treated as static.
// NOTE(review): currently unused — all call sites are commented out pending
// IsStatic support in the descriptors; not `static` to avoid an
// unused-function warning in the meantime.
uint32_t hasStaticShape(Type type) {
  if (auto shapedType = type.dyn_cast<ShapedType>()) {
    return shapedType.hasStaticShape() ? 1 : 0;
  }
  // Assume scalars and non-shaped type things are static
  return 1;
}
static LogicalResult createModuleMetadata(ModuleOp module) {
@ -59,10 +161,131 @@ static LogicalResult createModuleMetadata(ModuleOp module) {
}
// TODO: Add richer information here such as expected shapes and element
// types.
builder.create<refbackrt::FuncMetadataOp>(
func.getLoc(), builder.getSymbolRefAttr(func.getName()),
builder.getI32IntegerAttr(func.getNumArguments()),
builder.getI32IntegerAttr(func.getNumResults()));
SmallVector<uint32_t, 6> inputABIArgTypes;
SmallVector<uint32_t, 6> inputABIElementTypes;
SmallVector<SmallVector<int32_t, kMaxRank>, 6> inputABIShapes;
SmallVector<uint32_t, 6> inputABIRanks;
// SmallVector<uint32_t, 6> inputIsStatic;
for (const auto &inputArgType : func.getBody().front().getArgumentTypes()) {
inputABIArgTypes.push_back(getIntReprForABIType(inputArgType));
inputABIElementTypes.push_back(getIntReprForABIElementType(inputArgType));
inputABIShapes.push_back(
getExtentsForType(inputArgType, /*maxRank=*/kMaxRank));
inputABIRanks.push_back(getRankForType(inputArgType));
// inputIsStatic.push_back(hasStaticShape(inputArgType));
}
SmallVector<uint32_t, 6> outputABIArgTypes;
SmallVector<uint32_t, 6> outputABIElementTypes;
SmallVector<SmallVector<int32_t, kMaxRank>, 6> outputABIShapes;
SmallVector<uint32_t, 6> outputABIRanks;
SmallVector<uint32_t, 6> outputIsStatic;
for (const auto &outputArgType : func.getCallableResults()) {
outputABIArgTypes.push_back(getIntReprForABIType(outputArgType));
outputABIElementTypes.push_back(
getIntReprForABIElementType(outputArgType));
outputABIShapes.push_back(
getExtentsForType(outputArgType, /*maxRank=*/kMaxRank));
outputABIRanks.push_back(getRankForType(outputArgType));
// outputIsStatic.push_back(hasStaticShape(outputArgType));
}
auto i32Type = builder.getIntegerType(32);
auto inputABIDataType =
RankedTensorType::get(inputABIArgTypes.size(), i32Type);
auto inputABIElementType =
RankedTensorType::get(inputABIElementTypes.size(), i32Type);
auto inputABIShapesType = RankedTensorType::get(
llvm::ArrayRef<int64_t>{static_cast<long>(inputABIShapes.size()) *
kMaxRank},
i32Type);
auto inputABIRanksType =
RankedTensorType::get(inputABIRanks.size(), i32Type);
// auto inputIsStaticType = RankedTensorType::get(inputIsStatic.size(),
// i32Type);
auto outputABIDataType =
RankedTensorType::get(outputABIArgTypes.size(), i32Type);
auto outputABIElementType =
RankedTensorType::get(outputABIElementTypes.size(), i32Type);
auto outputABIShapesType = RankedTensorType::get(
llvm::ArrayRef<int64_t>{static_cast<long>(outputABIShapes.size()) *
kMaxRank},
i32Type);
auto outputABIRanksType =
RankedTensorType::get(outputABIRanks.size(), i32Type);
// auto outputIsStaticType = RankedTensorType::get(outputIsStatic.size(),
// i32Type);
// TODO(brycearden): I'm sure there's a cleaner way to do this
auto flattenABIShapes =
[](SmallVector<SmallVector<int32_t, kMaxRank>, 6> shapes) {
SmallVector<int32_t, 32> ret;
for (auto &shape : shapes)
for (auto &dim : shape)
ret.push_back(dim);
return ret;
};
SmallVector<NamedAttribute, 16> namedAttrs;
// Add attributes that are valid for every func (funcName, numInputs,
// numOutputs)
namedAttrs.push_back(
std::make_pair(Identifier::get("funcName", module.getContext()),
builder.getSymbolRefAttr(func.getName())));
namedAttrs.push_back(
std::make_pair(Identifier::get("numInputs", module.getContext()),
builder.getI32IntegerAttr(func.getNumArguments())));
namedAttrs.push_back(
std::make_pair(Identifier::get("numOutputs", module.getContext()),
builder.getI32IntegerAttr(func.getNumResults())));
if (inputABIArgTypes.size()) {
// Only add input information if there are inputs
namedAttrs.push_back(std::make_pair(
Identifier::get("inputArgTypes", func.getContext()),
DenseIntElementsAttr::get(inputABIDataType,
llvm::makeArrayRef(inputABIArgTypes))));
namedAttrs.push_back(std::make_pair(
Identifier::get("inputElementTypes", func.getContext()),
DenseIntElementsAttr::get(inputABIElementType,
llvm::makeArrayRef(inputABIElementTypes))));
namedAttrs.push_back(std::make_pair(
Identifier::get("inputRanks", func.getContext()),
DenseIntElementsAttr::get(inputABIRanksType,
llvm::makeArrayRef(inputABIRanks))));
auto inputShapesFlattened = flattenABIShapes(inputABIShapes);
namedAttrs.push_back(std::make_pair(
Identifier::get("inputShapes", func.getContext()),
DenseIntElementsAttr::get(
inputABIShapesType,
llvm::makeArrayRef(flattenABIShapes(inputABIShapes)))));
}
if (outputABIArgTypes.size()) {
// Only add output information if there are outputs
namedAttrs.push_back(std::make_pair(
Identifier::get("outputArgTypes", func.getContext()),
DenseIntElementsAttr::get(outputABIDataType,
llvm::makeArrayRef(outputABIArgTypes))));
namedAttrs.push_back(std::make_pair(
Identifier::get("outputElementTypes", func.getContext()),
DenseIntElementsAttr::get(
outputABIElementType,
llvm::makeArrayRef(outputABIElementTypes))));
namedAttrs.push_back(std::make_pair(
Identifier::get("outputRanks", func.getContext()),
DenseIntElementsAttr::get(outputABIRanksType,
llvm::makeArrayRef(outputABIRanks))));
namedAttrs.push_back(std::make_pair(
Identifier::get("outputShapes", func.getContext()),
DenseIntElementsAttr::get(
outputABIShapesType,
llvm::makeArrayRef(flattenABIShapes(outputABIShapes)))));
}
builder.create<refbackrt::FuncMetadataOp>(func.getLoc(), ArrayRef<Type>{},
ArrayRef<Value>{}, namedAttrs);
if (!expressibleWithRefbackrtABI(func.getType()))
return func.emitError() << "func not expressible with refbackrt ABI";

View File

@ -23,6 +23,40 @@ namespace refbackrt {
// LowerToLLVM.cpp for more details.
typedef void ABIFunc(void **, void **);
// Wire encoding of an argument's top-level kind as it crosses the
// compiler/runtime ABI boundary. These integer values are baked into the
// compiled module metadata (see LowerToRefbackrtABI.cpp), so entries must
// only be appended, never renumbered or reordered.
enum class ABIArgType : std::uint32_t {
kNone = 0, // No / unknown argument kind.
kMemref, // Tensor argument passed as an (unranked) memref.
kF32, // 32-bit float scalar.
kF64, // 64-bit float scalar.
};
// Wire encoding of the element type for container-like arguments (e.g.
// Tensor<f32>). Scalar arguments use kNone. Values are part of the ABI
// metadata encoding, so entries must only be appended.
enum class ABIElementType : std::uint32_t {
kNone = 0, // No element type (scalar argument or unknown).
kF32, // f32 elements (the only element type currently supported).
};
// Per-input argument metadata emitted by the compiler and read by the
// runtime to type/shape-check user-provided inputs.
struct InputDescriptor {
ABIArgType abiType; // Top-level kind of this input (tensor, f32, ...).
ABIElementType elementType; // Element type when abiType is a container kind.
std::int32_t rank; // Number of valid entries in `extents`.
std::int32_t* extents; // Per-dimension sizes; presumably -1 encodes a dynamic dim (TODO confirm against LowerToRefbackrtABI.cpp).
// TODO(brycearden): Change to bool at ABI boundary
// std::int32_t isStatic;
};
// Per-output result metadata emitted by the compiler; mirrors
// InputDescriptor and is used by the runtime to pre-allocate outputs.
struct OutputDescriptor {
ABIArgType abiType; // Top-level kind of this output (tensor, f32, ...).
ABIElementType elementType; // Element type when abiType is a container kind.
std::int32_t rank; // Number of valid entries in `extents`.
std::int32_t* extents; // Per-dimension sizes; presumably -1 encodes a dynamic dim (TODO confirm against LowerToRefbackrtABI.cpp).
// TODO(brycearden): Change to bool at ABI boundary
//std::int32_t isStatic;
};
struct FuncDescriptor {
// The length of the function name.
std::int32_t nameLen;
@ -35,9 +69,9 @@ struct FuncDescriptor {
std::int32_t numInputs;
// The number of outputs of the function.
std::int32_t numOutputs;
// TODO: Add arg/result descriptors and other metadata.
// With those descriptors we can do type and shape checking for each
// argument.
// TODO: Add shape checking to arg / result descriptor(s)
InputDescriptor *inputDescriptors;
OutputDescriptor *outputDescriptors;
};
// The top-level entry point of the module metadata emitted by the

View File

@ -118,12 +118,38 @@ static std::int32_t totalElements(ArrayRef<std::int32_t> extents) {
// Returns the size in bytes of a single element of the given type.
std::int32_t refbackrt::getElementTypeByteSize(ElementType type) {
  if (type == ElementType::NONE)
    return 0;
  if (type == ElementType::F32)
    return 4;
  llvm_unreachable("unsupported dtype");
}
// Returns a human-readable name for the element type (used in diagnostics).
StringRef refbackrt::getElementTypeAsStringRef(ElementType type) {
  if (type == ElementType::NONE)
    return "NONE";
  if (type == ElementType::F32)
    return "F32";
  llvm_unreachable("unsupported element type string");
}
// Returns a human-readable name for the argument kind (used in diagnostics).
StringRef refbackrt::getArgTypeAsStringRef(ArgType type) {
  if (type == ArgType::kNone)
    return "kNone";
  if (type == ArgType::kTensor)
    return "kTensor";
  if (type == ArgType::kF32)
    return "kF32";
  if (type == ArgType::kF64)
    return "kF64";
  llvm_unreachable("unsupported arg type string");
}
Ref<Tensor> Tensor::create(ArrayRef<std::int32_t> extents, ElementType type,
void *data) {
return Ref<Tensor>(createRaw(extents, type, data));
@ -192,10 +218,7 @@ void refbackrt::invoke(ModuleDescriptor *moduleDescriptor,
// Deepcopy the refbackrt::Tensor's into UnrankedMemref's.
// TODO: Avoid the deep copy. It makes the later lifetime management code
// more complex though (and maybe impossible given the current abstractions).
for (int i = 0, e = inputs.size(); i < e; i++) {
inputUnrankedMemrefs[i] =
convertRefbackrtTensorToUnrankedMemref(inputs[i].toTensor().get());
}
//
// Create a type-erased list of "packed inputs" to pass to the
// LLVM/C ABI wrapper function. Each packedInput pointer corresponds to
// one LLVM/C ABI argument to the underlying function.
@ -204,16 +227,30 @@ void refbackrt::invoke(ModuleDescriptor *moduleDescriptor,
// "explode" the unranked memref descriptors on the underlying function
// into separate arguments for the rank and pointer-to-descriptor.
for (int i = 0, e = inputs.size(); i < e; i++) {
packedInputs[2 * i] = ToVoidPtr(&inputUnrankedMemrefs[i].rank);
packedInputs[2 * i + 1] = ToVoidPtr(&inputUnrankedMemrefs[i].descriptor);
auto idx = 2 * i;
if (inputs[i].isTensor()) {
inputUnrankedMemrefs[i] =
convertRefbackrtTensorToUnrankedMemref(inputs[i].toTensor().get());
packedInputs[idx] = ToVoidPtr(&inputUnrankedMemrefs[i].rank);
packedInputs[idx + 1] = ToVoidPtr(&inputUnrankedMemrefs[i].descriptor);
} else if (inputs[i].isScalar()) {
packedInputs[idx] = ToVoidPtr(&inputs[i]);
} else {
assert(false && "unsupported input RtValue type");
}
}
// Create a type-erased list of "packed output" to pass to the
// LLVM/C ABI wrapper function.
//
// Due to how StandardToLLVM lowering works, each packedOutput pointer
// corresponds to a single UnrankedMemref (not "exploded").
for (int i = 0, e = outputs.size(); i < e; i++) {
packedOutputs[i] = ToVoidPtr(&outputUnrankedMemrefs[i]);
if (outputs[i].isTensor()) {
packedOutputs[i] = ToVoidPtr(&outputUnrankedMemrefs[i]);
} else if (outputs[i].isScalar()) {
packedOutputs[i] = ToVoidPtr(&outputs[i]);
}
}
// Actually invoke the function!
@ -223,11 +260,15 @@ void refbackrt::invoke(ModuleDescriptor *moduleDescriptor,
// TODO: Avoid needing to make a deep copy.
for (int i = 0, e = outputs.size(); i < e; i++) {
// TODO: Have compiler emit the element type in the metadata.
auto elementType = ElementType::F32;
Tensor *tensor = convertUnrankedMemrefToRefbackrtTensor(
outputUnrankedMemrefs[i].rank, outputUnrankedMemrefs[i].descriptor,
elementType);
outputs[i] = RtValue(Ref<Tensor>(tensor));
if (outputs[i].isTensor()) {
auto elementType = ElementType::F32;
Tensor *tensor = convertUnrankedMemrefToRefbackrtTensor(
outputUnrankedMemrefs[i].rank, outputUnrankedMemrefs[i].descriptor,
elementType);
outputs[i] = RtValue(Ref<Tensor>(tensor));
} else if (outputs[i].isFloat()) {
outputs[i] = RtValue(*(reinterpret_cast<float *>(packedOutputs[i])));
}
}
// Now, we just need to free all the UnrankedMemref's that we created.
@ -239,24 +280,30 @@ void refbackrt::invoke(ModuleDescriptor *moduleDescriptor,
// Free the output buffers.
for (int i = 0, e = outputs.size(); i < e; i++) {
void *allocatedPtr = outputUnrankedMemrefs[i].descriptor->allocatedPtr;
// Multiple returned memrefs can point into the same underlying
// malloc allocation. Do a linear scan to see if any of the previously
// deallocated buffers already freed this pointer.
bool bufferNeedsFreeing = true;
for (int j = 0; j < i; j++) {
if (allocatedPtr == outputUnrankedMemrefs[j].descriptor->allocatedPtr)
bufferNeedsFreeing = false;
if (outputs[i].isRef()) {
void *allocatedPtr = outputUnrankedMemrefs[i].descriptor->allocatedPtr;
// Multiple returned memrefs can point into the same underlying
// malloc allocation. Do a linear scan to see if any of the previously
// deallocated buffers already freed this pointer.
bool bufferNeedsFreeing = true;
for (int j = 0; j < i; j++) {
if (allocatedPtr == outputUnrankedMemrefs[j].descriptor->allocatedPtr)
bufferNeedsFreeing = false;
}
if (!bufferNeedsFreeing)
std::free(allocatedPtr);
}
if (!bufferNeedsFreeing)
std::free(allocatedPtr);
}
// Free the input buffers.
for (int i = 0, e = inputs.size(); i < e; i++) {
if (!inputs[i].isRef())
continue;
void *allocatedPtr = inputUnrankedMemrefs[i].descriptor->allocatedPtr;
bool bufferNeedsFreeing = true;
for (int j = 0, je = outputs.size(); j < je; j++) {
if (!outputs[j].isRef())
continue;
if (allocatedPtr == outputUnrankedMemrefs[j].descriptor->allocatedPtr)
bufferNeedsFreeing = false;
}
@ -274,6 +321,8 @@ void refbackrt::invoke(ModuleDescriptor *moduleDescriptor,
// Free the output descriptors.
for (int i = 0, e = outputs.size(); i < e; i++) {
if (!outputs[i].isRef())
continue;
// The LLVM lowering guarantees that each returned unranked memref
// descriptor is separately malloc'ed, so no need to do anything special
// like we had to do for the allocatedPtr's.
@ -281,10 +330,81 @@ void refbackrt::invoke(ModuleDescriptor *moduleDescriptor,
}
// Free the input descriptors.
for (int i = 0, e = inputs.size(); i < e; i++) {
if (!inputs[i].isRef())
continue;
std::free(inputUnrankedMemrefs[i].descriptor);
}
}
// Translates the raw ABI-level input descriptor into the external
// InputArgInfo representation used for user-facing type/shape checking.
static InputArgInfo
getExternalInputArgInfo(const refbackrt::InputDescriptor &inputDescriptor) {
  InputArgInfo info;

  // Map the ABI arg kind onto the external (argType, elementType) pair.
  switch (inputDescriptor.abiType) {
  case ABIArgType::kNone:
    info.argType = ArgType::kNone;
    info.elementType = ElementType::NONE;
    break;
  case ABIArgType::kMemref:
    // Only f32 tensors are currently supported at the ABI boundary.
    info.argType = ArgType::kTensor;
    info.elementType = ElementType::F32;
    break;
  case ABIArgType::kF32:
    info.argType = ArgType::kF32;
    info.elementType = ElementType::NONE;
    break;
  case ABIArgType::kF64:
    info.argType = ArgType::kF64;
    info.elementType = ElementType::NONE;
    break;
  default:
    assert(false && "need to update external internal map");
  }

  // Copy over rank and the per-dimension extents.
  info.rank = inputDescriptor.rank;
  for (int dim = 0; dim < inputDescriptor.rank; dim++)
    info.extents[dim] = inputDescriptor.extents[dim];

  return info;
}
// Translates the raw ABI-level output descriptor into the external
// OutputArgInfo representation used when pre-allocating outputs.
static OutputArgInfo
getExternalOutputArgInfo(const refbackrt::OutputDescriptor &outputDescriptor) {
  OutputArgInfo info;

  // Map the ABI arg kind onto the external (argType, elementType) pair.
  switch (outputDescriptor.abiType) {
  case ABIArgType::kNone:
    info.argType = ArgType::kNone;
    info.elementType = ElementType::NONE;
    break;
  case ABIArgType::kMemref:
    // Only f32 tensors are currently supported at the ABI boundary.
    info.argType = ArgType::kTensor;
    info.elementType = ElementType::F32;
    break;
  case ABIArgType::kF32:
    info.argType = ArgType::kF32;
    info.elementType = ElementType::NONE;
    break;
  case ABIArgType::kF64:
    info.argType = ArgType::kF64;
    info.elementType = ElementType::NONE;
    break;
  default:
    assert(false && "need to update external internal map");
  }

  // Copy over rank and the per-dimension extents.
  info.rank = outputDescriptor.rank;
  for (int dim = 0; dim < outputDescriptor.rank; dim++)
    info.extents[dim] = outputDescriptor.extents[dim];

  return info;
}
LogicalResult refbackrt::getMetadata(ModuleDescriptor *moduleDescriptor,
StringRef functionName,
FunctionMetadata &outMetadata) {
@ -293,5 +413,107 @@ LogicalResult refbackrt::getMetadata(ModuleDescriptor *moduleDescriptor,
return failure();
outMetadata.numInputs = descriptor->numInputs;
outMetadata.numOutputs = descriptor->numOutputs;
for (int i = 0; i < descriptor->numInputs; i++) {
outMetadata.inputArgInfos[i] =
getExternalInputArgInfo(descriptor->inputDescriptors[i]);
}
for (int i = 0; i < descriptor->numOutputs; i++) {
outMetadata.outputArgInfos[i] =
getExternalOutputArgInfo(descriptor->outputDescriptors[i]);
}
return success();
}
// Verifies that a user-provided input value matches the shape the compiler
// expected for that argument. Non-tensor values and unranked expectations
// (rank < 0) always pass; dynamic dims (extent <= 0) match any size.
LogicalResult refbackrt::checkRtValueShapes(const RtValue &value,
                                            const InputArgInfo &info) {
  if (!value.isTensor())
    return success();

  auto tensor = value.toTensor();
  // Don't bother checking shapes for unranked tensors.
  if (info.rank < 0)
    return success();
  if (tensor->getRank() != info.rank)
    return failure();

  auto actualExtents = tensor->getExtents();
  for (int dim = 0; dim < info.rank; dim++) {
    // A dynamic dimension is encoded as extent = -1; skip checking it.
    const auto expected = info.extents[dim];
    if (expected > 0 && expected != actualExtents[dim])
      return failure();
  }
  return success();
}
// Verifies that a user-provided input value matches the argument kind the
// compiler expected: tensors must be declared kTensor and floats kF32
// (only f32 scalars are supported today). Ref-counted values additionally
// get element-type checks.
LogicalResult refbackrt::checkRtValueArgTypes(const RtValue &value,
                                              const InputArgInfo &info) {
  const bool tensorMismatch =
      value.isTensor() && info.argType != ArgType::kTensor;
  const bool floatMismatch = value.isFloat() && info.argType != ArgType::kF32;
  if (tensorMismatch || floatMismatch)
    return failure();

  if (!value.isRef())
    return success();

  // Ref-counted types need special per-type checking; currently only f32
  // tensors are supported.
  if (value.isTensor()) {
    auto tensor = value.toTensor();
    if (tensor->getElementType() != ElementType::F32)
      return failure();
    return success();
  }
  assert(false && "Unsupported input type checking for Ref type");
  return success();
}
// Creates a placeholder RtValue for an output slot described by `info`.
// Tensor outputs get a zero-filled f32 buffer (dynamic dims are given a
// fixed placeholder extent); scalar f32 outputs get a sentinel value that
// the invoked function overwrites through the packed output pointer.
RtValue refbackrt::createRtValueFromOutputArgInfo(const OutputArgInfo &info) {
  // HACK: for dynamic dims the shape will be negative, so for now we just
  // allocate a fixed placeholder extent of kDynamicConstantShape.
  constexpr int32_t kDynamicConstantShape = 100;
  switch (info.argType) {
  case ArgType::kTensor: {
    std::array<int32_t, kMaxRank> tensorShape;
    for (int i = 0; i < info.rank; i++) {
      tensorShape[i] =
          info.extents[i] > 0 ? info.extents[i] : kDynamicConstantShape;
    }
    refbackrt::ArrayRef<int32_t> shape(tensorShape.data(), info.rank);
    int numel = 1;
    for (int i = 0; i < info.rank; i++)
      numel *= shape[i];

    switch (info.elementType) {
    case ElementType::F32: {
      auto byteSize = numel * sizeof(float);
      // aligned_alloc requires the size to be a multiple of the alignment;
      // round the allocation up so the call is well-defined.
      auto allocSize = (byteSize + 31) & ~static_cast<size_t>(31);
      void *data = aligned_alloc(32, allocSize);
      assert(data && "data ptr must exist");
      memset(data, 0, byteSize);
      // Tensor::create mallocs and memcpys the data into the Tensor object,
      // so we must free our temporary buffer afterwards. (The previous code
      // returned before reaching the free, leaking the buffer.)
      auto refTensor = Tensor::create(shape, ElementType::F32, data);
      free(data);
      return RtValue(refTensor);
    }
    default: {
      assert(false && "unknown output tensor type");
      return RtValue();
    }
    }
  }
  case ArgType::kF32: {
    // Placeholder scalar; the callee overwrites it.
    return RtValue(-20.0f);
  }
  default: {
    assert(false && "Don't know how to handle this argType");
    return RtValue();
  }
  }
}

View File

@ -3,9 +3,17 @@
// CHECK: refbackrt.module_metadata
refbackrt.module_metadata {
// CHECK: refbackrt.func_metadata
refbackrt.func_metadata {funcName = @f, numInputs = 1 : i32, numOutputs = 0 : i32}
// TODO(brycearden): Encode unranked memrefs in the ABI
refbackrt.func_metadata {
funcName = @f,
numInputs = 1 : i32,
numOutputs = 0 : i32,
inputArgTypes = dense<1> : tensor<1xi32>,
inputElementTypes = dense<1> : tensor<1xi32>,
inputRanks = dense<-1> : tensor<1xi32>,
inputShapes = dense<1> : tensor<4xi32>}
}
func @f(%arg0: memref<*xf32>) {
func @f(%arg0: tensor<*xf32>) {
return
}

View File

@ -1,165 +0,0 @@
// RUN: npcomp-opt -refback-lower-to-llvm -split-input-file <%s | FileCheck %s --dump-input=fail
// Test input/output arg marshaling.
// CHECK-LABEL: llvm.func @__refbackrt_wrapper_inputs1results2(
// CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr<ptr<i8>>,
// CHECK-SAME: %[[VAL_1:.*]]: !llvm.ptr<ptr<i8>>) {
// CHECK: %[[VAL_2:.*]] = llvm.mlir.constant(0 : i32) : i32
// CHECK: %[[VAL_3:.*]] = llvm.getelementptr %[[VAL_0]]{{\[}}%[[VAL_2]]] : (!llvm.ptr<ptr<i8>>, i32) -> !llvm.ptr<ptr<i8>>
// CHECK: %[[VAL_4:.*]] = llvm.load %[[VAL_3]] : !llvm.ptr<ptr<i8>>
// CHECK: %[[VAL_5:.*]] = llvm.bitcast %[[VAL_4]] : !llvm.ptr<i8> to !llvm.ptr<i64>
// CHECK: %[[VAL_6:.*]] = llvm.load %[[VAL_5]] : !llvm.ptr<i64>
// CHECK: %[[VAL_7:.*]] = llvm.mlir.constant(1 : i32) : i32
// CHECK: %[[VAL_8:.*]] = llvm.getelementptr %[[VAL_0]]{{\[}}%[[VAL_7]]] : (!llvm.ptr<ptr<i8>>, i32) -> !llvm.ptr<ptr<i8>>
// CHECK: %[[VAL_9:.*]] = llvm.load %[[VAL_8]] : !llvm.ptr<ptr<i8>>
// CHECK: %[[VAL_10:.*]] = llvm.bitcast %[[VAL_9]] : !llvm.ptr<i8> to !llvm.ptr<ptr<i8>>
// CHECK: %[[VAL_11:.*]] = llvm.load %[[VAL_10]] : !llvm.ptr<ptr<i8>>
// CHECK: %[[VAL_12:.*]] = llvm.call @inputs1results2(%[[VAL_6]], %[[VAL_11]]) : (i64, !llvm.ptr<i8>) -> !llvm.struct<(struct<(i64, ptr<i8>)>, struct<(i64, ptr<i8>)>)>
// CHECK: %[[VAL_13:.*]] = llvm.mlir.constant(0 : i32) : i32
// CHECK: %[[VAL_14:.*]] = llvm.getelementptr %[[VAL_1]]{{\[}}%[[VAL_13]]] : (!llvm.ptr<ptr<i8>>, i32) -> !llvm.ptr<ptr<i8>>
// CHECK: %[[VAL_15:.*]] = llvm.load %[[VAL_14]] : !llvm.ptr<ptr<i8>>
// CHECK: %[[VAL_16:.*]] = llvm.bitcast %[[VAL_15]] : !llvm.ptr<i8> to !llvm.ptr<struct<(i64, ptr<i8>)>>
// CHECK: %[[VAL_17:.*]] = llvm.extractvalue %[[VAL_12]][0 : i32] : !llvm.struct<(struct<(i64, ptr<i8>)>, struct<(i64, ptr<i8>)>)>
// CHECK: llvm.store %[[VAL_17]], %[[VAL_16]] : !llvm.ptr<struct<(i64, ptr<i8>)>>
// CHECK: %[[VAL_18:.*]] = llvm.mlir.constant(1 : i32) : i32
// CHECK: %[[VAL_19:.*]] = llvm.getelementptr %[[VAL_1]]{{\[}}%[[VAL_18]]] : (!llvm.ptr<ptr<i8>>, i32) -> !llvm.ptr<ptr<i8>>
// CHECK: %[[VAL_20:.*]] = llvm.load %[[VAL_19]] : !llvm.ptr<ptr<i8>>
// CHECK: %[[VAL_21:.*]] = llvm.bitcast %[[VAL_20]] : !llvm.ptr<i8> to !llvm.ptr<struct<(i64, ptr<i8>)>>
// CHECK: %[[VAL_22:.*]] = llvm.extractvalue %[[VAL_12]][1 : i32] : !llvm.struct<(struct<(i64, ptr<i8>)>, struct<(i64, ptr<i8>)>)>
// CHECK: llvm.store %[[VAL_22]], %[[VAL_21]] : !llvm.ptr<struct<(i64, ptr<i8>)>>
// CHECK: llvm.return
// CHECK: }
// CHECK-LABEL: llvm.func @__refbackrt_wrapper_inputs1results1(
// CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr<ptr<i8>>,
// CHECK-SAME: %[[VAL_1:.*]]: !llvm.ptr<ptr<i8>>) {
// CHECK: %[[VAL_2:.*]] = llvm.mlir.constant(0 : i32) : i32
// CHECK: %[[VAL_3:.*]] = llvm.getelementptr %[[VAL_0]]{{\[}}%[[VAL_2]]] : (!llvm.ptr<ptr<i8>>, i32) -> !llvm.ptr<ptr<i8>>
// CHECK: %[[VAL_4:.*]] = llvm.load %[[VAL_3]] : !llvm.ptr<ptr<i8>>
// CHECK: %[[VAL_5:.*]] = llvm.bitcast %[[VAL_4]] : !llvm.ptr<i8> to !llvm.ptr<i64>
// CHECK: %[[VAL_6:.*]] = llvm.load %[[VAL_5]] : !llvm.ptr<i64>
// CHECK: %[[VAL_7:.*]] = llvm.mlir.constant(1 : i32) : i32
// CHECK: %[[VAL_8:.*]] = llvm.getelementptr %[[VAL_0]]{{\[}}%[[VAL_7]]] : (!llvm.ptr<ptr<i8>>, i32) -> !llvm.ptr<ptr<i8>>
// CHECK: %[[VAL_9:.*]] = llvm.load %[[VAL_8]] : !llvm.ptr<ptr<i8>>
// CHECK: %[[VAL_10:.*]] = llvm.bitcast %[[VAL_9]] : !llvm.ptr<i8> to !llvm.ptr<ptr<i8>>
// CHECK: %[[VAL_11:.*]] = llvm.load %[[VAL_10]] : !llvm.ptr<ptr<i8>>
// CHECK: %[[VAL_12:.*]] = llvm.call @inputs1results1(%[[VAL_6]], %[[VAL_11]]) : (i64, !llvm.ptr<i8>) -> !llvm.struct<(i64, ptr<i8>)>
// CHECK: %[[VAL_13:.*]] = llvm.mlir.constant(0 : i32) : i32
// CHECK: %[[VAL_14:.*]] = llvm.getelementptr %[[VAL_1]]{{\[}}%[[VAL_13]]] : (!llvm.ptr<ptr<i8>>, i32) -> !llvm.ptr<ptr<i8>>
// CHECK: %[[VAL_15:.*]] = llvm.load %[[VAL_14]] : !llvm.ptr<ptr<i8>>
// CHECK: %[[VAL_16:.*]] = llvm.bitcast %[[VAL_15]] : !llvm.ptr<i8> to !llvm.ptr<struct<(i64, ptr<i8>)>>
// CHECK: llvm.store %[[VAL_12]], %[[VAL_16]] : !llvm.ptr<struct<(i64, ptr<i8>)>>
// CHECK: llvm.return
// CHECK: }
/// CHECK-LABEL: llvm.func @__refbackrt_wrapper_inputs1results0(
// CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr<ptr<i8>>,
// CHECK-SAME: %[[VAL_1:.*]]: !llvm.ptr<ptr<i8>>) {
// CHECK: %[[VAL_2:.*]] = llvm.mlir.constant(0 : i32) : i32
// CHECK: %[[VAL_3:.*]] = llvm.getelementptr %[[VAL_0]]{{\[}}%[[VAL_2]]] : (!llvm.ptr<ptr<i8>>, i32) -> !llvm.ptr<ptr<i8>>
// CHECK: %[[VAL_4:.*]] = llvm.load %[[VAL_3]] : !llvm.ptr<ptr<i8>>
// CHECK: %[[VAL_5:.*]] = llvm.bitcast %[[VAL_4]] : !llvm.ptr<i8> to !llvm.ptr<i64>
// CHECK: %[[VAL_6:.*]] = llvm.load %[[VAL_5]] : !llvm.ptr<i64>
// CHECK: %[[VAL_7:.*]] = llvm.mlir.constant(1 : i32) : i32
// CHECK: %[[VAL_8:.*]] = llvm.getelementptr %[[VAL_0]]{{\[}}%[[VAL_7]]] : (!llvm.ptr<ptr<i8>>, i32) -> !llvm.ptr<ptr<i8>>
// CHECK: %[[VAL_9:.*]] = llvm.load %[[VAL_8]] : !llvm.ptr<ptr<i8>>
// CHECK: %[[VAL_10:.*]] = llvm.bitcast %[[VAL_9]] : !llvm.ptr<i8> to !llvm.ptr<ptr<i8>>
// CHECK: %[[VAL_11:.*]] = llvm.load %[[VAL_10]] : !llvm.ptr<ptr<i8>>
// CHECK: llvm.call @inputs1results0(%[[VAL_6]], %[[VAL_11]]) : (i64, !llvm.ptr<i8>) -> ()
// CHECK: llvm.return
// CHECK: }
// CHECK: llvm.func @__npcomp_compiler_rt_abort_if(i1, !llvm.ptr<i8>)
// CHECK: llvm.mlir.global internal constant @__npcomp_internal_constant_inputs1results0("inputs1results0")
// CHECK: llvm.mlir.global internal constant @__npcomp_internal_constant_inputs1results1("inputs1results1")
// CHECK: llvm.mlir.global internal constant @__npcomp_internal_constant_inputs1results2("inputs1results2")
// CHECK-LABEL: llvm.mlir.global internal constant @__npcomp_func_descriptors() : !llvm.array<3 x struct<(i32, ptr<i8>, ptr<i8>, i32, i32)>> {
// CHECK: %[[VAL_0:.*]] = llvm.mlir.undef : !llvm.array<3 x struct<(i32, ptr<i8>, ptr<i8>, i32, i32)>>
// CHECK: %[[VAL_1:.*]] = llvm.mlir.constant(0 : i32) : i32
// CHECK: %[[VAL_2:.*]] = llvm.mlir.constant(15 : i32) : i32
// CHECK: %[[VAL_3:.*]] = llvm.insertvalue %[[VAL_2]], %[[VAL_0]][0 : i32, 0 : i32] : !llvm.array<3 x struct<(i32, ptr<i8>, ptr<i8>, i32, i32)>>
// CHECK: %[[VAL_4:.*]] = llvm.mlir.addressof @__npcomp_internal_constant_inputs1results0 : !llvm.ptr<array<15 x i8>>
// CHECK: %[[VAL_5:.*]] = llvm.getelementptr %[[VAL_4]]{{\[}}%[[VAL_1]], %[[VAL_1]]] : (!llvm.ptr<array<15 x i8>>, i32, i32) -> !llvm.ptr<i8>
// CHECK: %[[VAL_6:.*]] = llvm.insertvalue %[[VAL_5]], %[[VAL_3]][0 : i32, 1 : i32] : !llvm.array<3 x struct<(i32, ptr<i8>, ptr<i8>, i32, i32)>>
// CHECK: %[[VAL_7:.*]] = llvm.mlir.addressof @__refbackrt_wrapper_inputs1results0 : !llvm.ptr<func<void (ptr<ptr<i8>>, ptr<ptr<i8>>)>>
// CHECK: %[[VAL_8:.*]] = llvm.bitcast %[[VAL_7]] : !llvm.ptr<func<void (ptr<ptr<i8>>, ptr<ptr<i8>>)>> to !llvm.ptr<i8>
// CHECK: %[[VAL_9:.*]] = llvm.insertvalue %[[VAL_8]], %[[VAL_6]][0 : i32, 2 : i32] : !llvm.array<3 x struct<(i32, ptr<i8>, ptr<i8>, i32, i32)>>
// CHECK: %[[VAL_10:.*]] = llvm.mlir.constant(1 : i32) : i32
// CHECK: %[[VAL_11:.*]] = llvm.insertvalue %[[VAL_10]], %[[VAL_9]][0 : i32, 3 : i32] : !llvm.array<3 x struct<(i32, ptr<i8>, ptr<i8>, i32, i32)>>
// CHECK: %[[VAL_12:.*]] = llvm.mlir.constant(0 : i32) : i32
// CHECK: %[[VAL_13:.*]] = llvm.insertvalue %[[VAL_12]], %[[VAL_11]][0 : i32, 4 : i32] : !llvm.array<3 x struct<(i32, ptr<i8>, ptr<i8>, i32, i32)>>
// CHECK: %[[VAL_14:.*]] = llvm.mlir.constant(15 : i32) : i32
// CHECK: %[[VAL_15:.*]] = llvm.insertvalue %[[VAL_14]], %[[VAL_13]][1 : i32, 0 : i32] : !llvm.array<3 x struct<(i32, ptr<i8>, ptr<i8>, i32, i32)>>
// CHECK: %[[VAL_16:.*]] = llvm.mlir.addressof @__npcomp_internal_constant_inputs1results1 : !llvm.ptr<array<15 x i8>>
// CHECK: %[[VAL_17:.*]] = llvm.getelementptr %[[VAL_16]]{{\[}}%[[VAL_1]], %[[VAL_1]]] : (!llvm.ptr<array<15 x i8>>, i32, i32) -> !llvm.ptr<i8>
// CHECK: %[[VAL_18:.*]] = llvm.insertvalue %[[VAL_17]], %[[VAL_15]][1 : i32, 1 : i32] : !llvm.array<3 x struct<(i32, ptr<i8>, ptr<i8>, i32, i32)>>
// CHECK: %[[VAL_19:.*]] = llvm.mlir.addressof @__refbackrt_wrapper_inputs1results1 : !llvm.ptr<func<void (ptr<ptr<i8>>, ptr<ptr<i8>>)>>
// CHECK: %[[VAL_20:.*]] = llvm.bitcast %[[VAL_19]] : !llvm.ptr<func<void (ptr<ptr<i8>>, ptr<ptr<i8>>)>> to !llvm.ptr<i8>
// CHECK: %[[VAL_21:.*]] = llvm.insertvalue %[[VAL_20]], %[[VAL_18]][1 : i32, 2 : i32] : !llvm.array<3 x struct<(i32, ptr<i8>, ptr<i8>, i32, i32)>>
// CHECK: %[[VAL_22:.*]] = llvm.mlir.constant(1 : i32) : i32
// CHECK: %[[VAL_23:.*]] = llvm.insertvalue %[[VAL_22]], %[[VAL_21]][1 : i32, 3 : i32] : !llvm.array<3 x struct<(i32, ptr<i8>, ptr<i8>, i32, i32)>>
// CHECK: %[[VAL_24:.*]] = llvm.mlir.constant(1 : i32) : i32
// CHECK: %[[VAL_25:.*]] = llvm.insertvalue %[[VAL_24]], %[[VAL_23]][1 : i32, 4 : i32] : !llvm.array<3 x struct<(i32, ptr<i8>, ptr<i8>, i32, i32)>>
// CHECK: %[[VAL_26:.*]] = llvm.mlir.constant(15 : i32) : i32
// CHECK: %[[VAL_27:.*]] = llvm.insertvalue %[[VAL_26]], %[[VAL_25]][2 : i32, 0 : i32] : !llvm.array<3 x struct<(i32, ptr<i8>, ptr<i8>, i32, i32)>>
// CHECK: %[[VAL_28:.*]] = llvm.mlir.addressof @__npcomp_internal_constant_inputs1results2 : !llvm.ptr<array<15 x i8>>
// CHECK: %[[VAL_29:.*]] = llvm.getelementptr %[[VAL_28]]{{\[}}%[[VAL_1]], %[[VAL_1]]] : (!llvm.ptr<array<15 x i8>>, i32, i32) -> !llvm.ptr<i8>
// CHECK: %[[VAL_30:.*]] = llvm.insertvalue %[[VAL_29]], %[[VAL_27]][2 : i32, 1 : i32] : !llvm.array<3 x struct<(i32, ptr<i8>, ptr<i8>, i32, i32)>>
// CHECK: %[[VAL_31:.*]] = llvm.mlir.addressof @__refbackrt_wrapper_inputs1results2 : !llvm.ptr<func<void (ptr<ptr<i8>>, ptr<ptr<i8>>)>>
// CHECK: %[[VAL_32:.*]] = llvm.bitcast %[[VAL_31]] : !llvm.ptr<func<void (ptr<ptr<i8>>, ptr<ptr<i8>>)>> to !llvm.ptr<i8>
// CHECK: %[[VAL_33:.*]] = llvm.insertvalue %[[VAL_32]], %[[VAL_30]][2 : i32, 2 : i32] : !llvm.array<3 x struct<(i32, ptr<i8>, ptr<i8>, i32, i32)>>
// CHECK: %[[VAL_34:.*]] = llvm.mlir.constant(1 : i32) : i32
// CHECK: %[[VAL_35:.*]] = llvm.insertvalue %[[VAL_34]], %[[VAL_33]][2 : i32, 3 : i32] : !llvm.array<3 x struct<(i32, ptr<i8>, ptr<i8>, i32, i32)>>
// CHECK: %[[VAL_36:.*]] = llvm.mlir.constant(2 : i32) : i32
// CHECK: %[[VAL_37:.*]] = llvm.insertvalue %[[VAL_36]], %[[VAL_35]][2 : i32, 4 : i32] : !llvm.array<3 x struct<(i32, ptr<i8>, ptr<i8>, i32, i32)>>
// CHECK: llvm.return %[[VAL_37]] : !llvm.array<3 x struct<(i32, ptr<i8>, ptr<i8>, i32, i32)>>
// CHECK: }
// CHECK-LABEL: llvm.mlir.global external constant @_mlir___npcomp_module_descriptor() : !llvm.struct<(i32, ptr<struct<(i32, ptr<i8>, ptr<i8>, i32, i32)>>)> {
// CHECK: %[[VAL_0:.*]] = llvm.mlir.undef : !llvm.struct<(i32, ptr<struct<(i32, ptr<i8>, ptr<i8>, i32, i32)>>)>
// CHECK: %[[VAL_1:.*]] = llvm.mlir.constant(3 : i32) : i32
// CHECK: %[[VAL_2:.*]] = llvm.insertvalue %[[VAL_1]], %[[VAL_0]][0 : i32] : !llvm.struct<(i32, ptr<struct<(i32, ptr<i8>, ptr<i8>, i32, i32)>>)>
// CHECK: %[[VAL_3:.*]] = llvm.mlir.addressof @__npcomp_func_descriptors : !llvm.ptr<array<3 x struct<(i32, ptr<i8>, ptr<i8>, i32, i32)>>>
// CHECK: %[[VAL_4:.*]] = llvm.bitcast %[[VAL_3]] : !llvm.ptr<array<3 x struct<(i32, ptr<i8>, ptr<i8>, i32, i32)>>> to !llvm.ptr<struct<(i32, ptr<i8>, ptr<i8>, i32, i32)>>
// CHECK: %[[VAL_5:.*]] = llvm.insertvalue %[[VAL_4]], %[[VAL_2]][1 : i32] : !llvm.struct<(i32, ptr<struct<(i32, ptr<i8>, ptr<i8>, i32, i32)>>)>
// CHECK: llvm.return %[[VAL_5]] : !llvm.struct<(i32, ptr<struct<(i32, ptr<i8>, ptr<i8>, i32, i32)>>)>
// CHECK: }
refbackrt.module_metadata {
refbackrt.func_metadata {funcName = @inputs1results0, numInputs = 1 : i32, numOutputs = 0 : i32}
refbackrt.func_metadata {funcName = @inputs1results1, numInputs = 1 : i32, numOutputs = 1 : i32}
refbackrt.func_metadata {funcName = @inputs1results2, numInputs = 1 : i32, numOutputs = 2 : i32}
}
func @inputs1results0(%arg0: memref<*xf32>) {
return
}
func @inputs1results1(%arg0: memref<*xf32>) -> memref<*xf32> {
return %arg0 : memref<*xf32>
}
func @inputs1results2(%arg0: memref<*xf32>) -> (memref<*xf32>, memref<*xf32>) {
return %arg0, %arg0 : memref<*xf32>, memref<*xf32>
}
// -----
// Test emission of compiler runtime functions.
// CHECK: llvm.mlir.global internal constant @[[STRSYM:.*]]("msg\00")
// CHECK: llvm.func @__npcomp_compiler_rt_abort_if(i1, !llvm.ptr<i8>)
// CHECK-LABEL: llvm.func @calls_abort_if(
// CHECK-SAME: %[[VAL_0:.*]]: i1) {
// CHECK: %[[VAL_0:.*]] = llvm.mlir.addressof @[[STRSYM]] : !llvm.ptr<array<4 x i8>>
// CHECK: %[[VAL_1:.*]] = llvm.mlir.constant(0 : i32) : i32
// CHECK: %[[VAL_2:.*]] = llvm.getelementptr %[[VAL_0]]{{\[}}%[[VAL_1]], %[[VAL_1]]] : (!llvm.ptr<array<4 x i8>>, i32, i32) -> !llvm.ptr<i8>
// CHECK: llvm.call @__npcomp_compiler_rt_abort_if(%[[VAL_3:.*]], %[[VAL_2]]) : (i1, !llvm.ptr<i8>) -> ()
// CHECK: llvm.return
func @calls_abort_if(%arg0: i1) {
refbackrt.abort_if %arg0, "msg"
return
}

View File

@ -3,8 +3,14 @@
// Test module metadata.
// CHECK: refbackrt.module_metadata
// CHECK-NEXT: refbackrt.func_metadata {funcName = @f_2inputs_0outputs, numInputs = 2 : i32, numOutputs = 0 : i32}
// CHECK-NEXT: refbackrt.func_metadata {funcName = @f_1input_2outputs, numInputs = 1 : i32, numOutputs = 2 : i32}
// CHECK-NEXT: refbackrt.func_metadata
// CHECK-SAME: funcName = @f_2inputs_0outputs
// CHECK-SAME: numInputs = 2
// CHECK-SAME: numOutputs = 0
// CHECK-NEXT: refbackrt.func_metadata
// CHECK-SAME: funcName = @f_1input_2outputs
// CHECK-SAME: numInputs = 1
// CHECK-SAME: numOutputs = 2
// This function only exists to test its metadata above.
func @f_2inputs_0outputs(%arg0: memref<?xf32>, %arg1: memref<?xf32>) {

View File

@ -8,5 +8,4 @@
func @basic(%arg0: tensor<?xf32>) -> tensor<?xf32> {
%0 = tcf.add %arg0, %arg0 : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
return %0 : tensor<?xf32>
}
}

View File

@ -0,0 +1,26 @@
// RUN: not npcomp-run-mlir %s \
// RUN: -invoke invalid_input_shape \
// RUN: -arg-value="dense<1.0> : tensor<2x2x2x2xf32>" \
// RUN: -shared-libs=%npcomp_runtime_shlib 2>&1 \
// RUN: | FileCheck %s -check-prefix=ARG0-INVALID
// RUN: not npcomp-run-mlir %s \
// RUN: -invoke invalid_input_shape_arg1 \
// RUN: -arg-value="dense<1.0> : tensor<1x2x5xf32>" \
// RUN: -arg-value="dense<1.0> : tensor<1x2x10xf32>" \
// RUN: -shared-libs=%npcomp_runtime_shlib 2>&1 \
// RUN: | FileCheck %s -check-prefix=ARG1-INVALID
// ARG0-INVALID: invoking 'invalid_input_shape': input shape mismatch (%arg0).
// ARG0-INVALID-SAME: actual (provided by user): (2x2x2x2)
// ARG0-INVALID-SAME: expected (from compiler): (1x2x3x4)
func @invalid_input_shape(%arg0: tensor<1x2x3x4xf32>) -> tensor<1x2x3x4xf32> {
return %arg0: tensor<1x2x3x4xf32>
}
// ARG1-INVALID: invoking 'invalid_input_shape_arg1': input shape mismatch (%arg1)
// ARG1-INVALID-SAME: actual (provided by user): (1x2x10)
// ARG1-INVALID-SAME: expected (from compiler): (1x4x?)
func @invalid_input_shape_arg1(%arg0: tensor<1x2x?xf32>, %arg1: tensor<1x4x?xf32>) {
return
}

View File

@ -0,0 +1,13 @@
// RUN: not npcomp-run-mlir %s \
// RUN: -invoke expects_one_tensor \
// RUN: -arg-value="1.0 : f32" \
// RUN: -shared-libs=%npcomp_runtime_shlib 2>&1 \
// RUN: | FileCheck %s
// CHECK: invoking 'expects_one_tensor': input argument type mismatch.
// CHECK-SAME: actual (provided by user): Float
// CHECK-SAME: expected (from compiler): kTensor
func @expects_one_tensor(%arg0: tensor<?xf32>) -> tensor<?xf32> {
%0 = tcf.add %arg0, %arg0 : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
return %0 : tensor<?xf32>
}

View File

@ -2,10 +2,21 @@
// RUN: -invoke scalar \
// RUN: -arg-value="dense<1.0> : tensor<f32>" \
// RUN: -shared-libs=%npcomp_runtime_shlib 2>&1 \
// RUN: | FileCheck %s
// RUN: | FileCheck %s --check-prefix=SCALAR
// CHECK: output #0: dense<2.000000e+00> : tensor<f32>
// RUN: npcomp-run-mlir %s \
// RUN: -invoke scalar_arg \
// RUN: -arg-value="2.5 : f32" \
// RUN: -shared-libs=%npcomp_runtime_shlib 2>&1 \
// RUN: | FileCheck %s --check-prefix=SCALAR_ARG
// SCALAR: output #0: dense<2.000000e+00> : tensor<f32>
func @scalar(%arg0: tensor<f32>) -> tensor<f32> {
%0 = tcf.add %arg0, %arg0 : (tensor<f32>, tensor<f32>) -> tensor<f32>
return %0 : tensor<f32>
}
// SCALAR_ARG: output #0: 2.500000e+00 : f32
func @scalar_arg(%arg0: f32) -> f32 {
return %arg0 : f32
}

View File

@ -58,6 +58,14 @@ convertAttrToTensor(Attribute attr) {
return make_string_error("unhandled argument");
}
// Converts a parsed MLIR attribute (e.g. "2.5 : f32") into a C++ float.
// Returns an error for attributes that are not float-typed FloatAttrs.
static Expected<float> convertAttrToFloat(Attribute attr) {
  auto type = attr.getType().dyn_cast<FloatType>();
  if (!type)
    return make_string_error("converting an argument to float that is not a FloatType");
  // dyn_cast may fail even when the attribute's type is a FloatType; the
  // previous code dereferenced the null result in that case.
  auto floatAttr = attr.dyn_cast<FloatAttr>();
  if (!floatAttr)
    return make_string_error("converting an argument to float that is not a FloatAttr");
  return floatAttr.getValue().convertToFloat();
}
// Parses each textual -arg-value (e.g. "2.5 : f32" or
// "dense<1.0> : tensor<f32>") into an MLIR Attribute and converts it into a
// refbackrt::RtValue suitable for invoking the JIT-compiled function.
// NOTE(review): this span is rendered diff residue — the "@ -66,12 ..." hunk
// marker below elides unchanged lines (including the loop header over
// argValues and the `ret` declaration), and both the old tensor-only code and
// the new type-dispatching code appear together; only one is live on disk.
static Expected<SmallVector<refbackrt::RtValue, 6>>
createInputs(ArrayRef<StringRef> argValues) {
MLIRContext context;
@ -66,12 +74,22 @@ createInputs(ArrayRef<StringRef> argValues) {
auto attr = parseAttribute(argValue, &context);
if (!attr)
return make_string_error(Twine("could not parse arg value: ") + argValue);
// NOTE(review): the next five lines are the pre-change (removed) code path.
// TODO(brycearden): Handle multiple input types
auto expectedTensor = convertAttrToTensor(attr);
if (!expectedTensor)
return expectedTensor.takeError();
ret.push_back(std::move(*expectedTensor));
// Post-change code: dispatch on the parsed attribute's type so both tensor
// and scalar-f32 arguments are accepted. Unrecognized types are silently
// skipped here — presumably caught later by the runtime's type check.
auto attrType = attr.getType();
if (attrType.isa<RankedTensorType>()) {
auto expectedTensor = convertAttrToTensor(attr);
if (!expectedTensor)
return expectedTensor.takeError();
ret.push_back(std::move(*expectedTensor));
} else if (attrType.isa<FloatType>()) {
auto expectedFloat = convertAttrToFloat(attr);
if (!expectedFloat)
return expectedFloat.takeError();
ret.push_back(refbackrt::RtValue(*expectedFloat));
}
}
return ret;
}
@ -92,34 +110,40 @@ static RankedTensorType getCorrespondingMLIRTensorType(refbackrt::Tensor &tensor
return RankedTensorType::get(extents, elementType);
}
/// Converts a runtime output value back into an MLIR Attribute for printing:
/// tensors become DenseFPElementsAttr, f32 scalars become a FloatAttr.
/// (This span in the rendered diff contained both the removed tensor-only
/// overload and the added RtValue version; this is the post-change function.)
static Attribute convertToMLIRAttribute(const refbackrt::RtValue &value,
                                        Builder &builder) {
  if (value.isTensor()) {
    auto &tensor = *(value.toTensor());
    RankedTensorType type = getCorrespondingMLIRTensorType(tensor, builder);
    switch (tensor.getElementType()) {
    case refbackrt::ElementType::F32: {
      // Copy the raw runtime buffer element-by-element into a dense attr.
      SmallVector<float, 100> values;
      auto *basePtr = tensor.getData<float>();
      for (int i = 0, e = type.getNumElements(); i < e; i++)
        values.push_back(basePtr[i]);
      return DenseFPElementsAttr::get(type, values);
    }
    }
  } else if (value.isFloat()) {
    return builder.getF32FloatAttr(value.toFloat());
  } else {
    assert(false && "could not convert value to mlir attribute");
  }
  // Reached for tensor element types other than F32.
  llvm_unreachable("unsupported dtype");
}
static void printOutput(refbackrt::Tensor &tensor, llvm::raw_ostream &os) {
static void printOutput(const refbackrt::RtValue &value, llvm::raw_ostream &os) {
MLIRContext context;
Builder builder(&context);
auto attr = convertToMLIRAttribute(tensor, builder);
auto attr = convertToMLIRAttribute(value, builder);
attr.print(os);
}
/// Prints each invocation result as "output #<i>: <attr>\n".
/// The pre-change tensor-only assertion (stale in the rendered diff) is gone:
/// scalar RtValues are now handled by printOutput directly.
static void printOutputs(ArrayRef<refbackrt::RtValue> outputs,
                         llvm::raw_ostream &os) {
  for (auto output : llvm::enumerate(outputs)) {
    os << "output #" << output.index() << ": ";
    printOutput(output.value(), os);
    os << "\n";
  }
}
@ -150,9 +174,11 @@ Error compileAndRun(std::string mlirFile, mlir::MLIRContext &context,
auto expectedInputs = createInputs(argValues);
if (!expectedInputs)
return expectedInputs.takeError();
auto expectedOutputs = jitModule->invoke(invokeFunction, *expectedInputs);
if (!expectedOutputs)
return expectedOutputs.takeError();
auto outputs = std::move(*expectedOutputs);
printOutputs(outputs, llvm::outs());
llvm::outs() << "SUCCESS\n";