mirror of https://github.com/llvm/torch-mlir
[refbackrt] Scalar arg support
* Adds f32 scalar argument support across the ABI boundary. * Adds support for passing input type / shape information across the ABI boundary * Adds support for parsing / creating input FloatAttr's in `npcomp-run-mlir`pull/197/head
parent
703428eff4
commit
4591884d06
|
@ -68,12 +68,47 @@ def Refbackrt_FuncMetadataOp
|
|||
let description = [{
|
||||
Runtime metadata for a single func.
|
||||
|
||||
TODO: Augment this with information for type/shape checking of arguments.
|
||||
Contains type / shape information for arguments as described below:
|
||||
|
||||
* ArgType(s):
|
||||
Integer value from `CompilerDataStructures.h` for each argument
|
||||
indicating what type it is (e.g. Float, Int, Tensor, Dict, etc.)
|
||||
* ElementType(s):
|
||||
Certain input ArgType's also have an element type (e.g. Tensor<float>,
|
||||
List<int>, etc.)
|
||||
TODO(brycearden): Support nested types (e.g. List<Tensor<float>>)
|
||||
* Rank(s):
|
||||
Integer value indicating the rank for each argument.
|
||||
* Shape(s):
|
||||
Flattened hyper-rectangular representation of the shapes for each argument.
|
||||
Since each shape's size varies based on the Rank, we pad out the shapes
|
||||
to size kMaxRank to make ABI lowering easier. See `LowerToRefbackrtABI.cpp`
|
||||
for details.
|
||||
|
||||
Shapes Example:
|
||||
constexpr int kMaxRank = 6;
|
||||
// func @f(%arg0: f32, %arg1: tensor<5xf32>) would result in...
|
||||
inputShapes = dense<...> : tensor<12xi32>
|
||||
// 2 shapes with 6 elements each so that the LowerToLLVM pass
|
||||
// where only the first `rank` values in each shape are valid.
|
||||
//
|
||||
// can update the struct(s) by just grabbing a pointer at
|
||||
// %shape_ptr = %base + (kMaxRank * argIndex)
|
||||
}];
|
||||
let arguments = (ins
|
||||
FlatSymbolRefAttr:$funcName,
|
||||
I32Attr:$numInputs,
|
||||
I32Attr:$numOutputs
|
||||
I32Attr:$numOutputs,
|
||||
OptionalAttr<I32ElementsAttr>:$inputArgTypes,
|
||||
OptionalAttr<I32ElementsAttr>:$inputElementTypes,
|
||||
OptionalAttr<I32ElementsAttr>:$inputRanks,
|
||||
OptionalAttr<I32ElementsAttr>:$inputShapes,
|
||||
// I32ElementsAttr:$inputIsStatic,
|
||||
OptionalAttr<I32ElementsAttr>:$outputArgTypes,
|
||||
OptionalAttr<I32ElementsAttr>:$outputElementTypes,
|
||||
OptionalAttr<I32ElementsAttr>:$outputRanks,
|
||||
OptionalAttr<I32ElementsAttr>:$outputShapes
|
||||
//I32ElementsAttr:$outputIsStatic
|
||||
);
|
||||
let results = (outs);
|
||||
let assemblyFormat = "attr-dict";
|
||||
|
|
|
@ -31,6 +31,8 @@ public:
|
|||
return std::memcmp(ptr, other.ptr, length) == 0;
|
||||
}
|
||||
|
||||
const char* str() { return ptr; }
|
||||
|
||||
private:
|
||||
const char *ptr;
|
||||
std::size_t length;
|
||||
|
|
|
@ -22,6 +22,7 @@
|
|||
#define NPCOMP_RUNTIME_USERAPI_H
|
||||
|
||||
#include "npcomp/RefBackend/Runtime/Support.h"
|
||||
#include <array>
|
||||
#include <atomic>
|
||||
#include <cstdlib>
|
||||
|
||||
|
@ -105,9 +106,11 @@ private:
|
|||
|
||||
// The available data types.
|
||||
enum class ElementType : std::int32_t {
|
||||
NONE,
|
||||
F32,
|
||||
};
|
||||
std::int32_t getElementTypeByteSize(ElementType type);
|
||||
StringRef getElementTypeAsStringRef(ElementType type);
|
||||
|
||||
// Representation of a tensor.
|
||||
class Tensor : public RefTarget {
|
||||
|
@ -124,6 +127,12 @@ public:
|
|||
static Tensor *createRaw(ArrayRef<std::int32_t> extents,
|
||||
ElementType elementType, void *data);
|
||||
|
||||
static Ref<Tensor> create(ArrayRef<std::int64_t> extents,
|
||||
ElementType elementType, void *data);
|
||||
// Same as `create`, but returns a raw pointer.
|
||||
static Tensor *createRaw(ArrayRef<std::int64_t> extents,
|
||||
ElementType elementType, void *data);
|
||||
|
||||
ElementType getElementType() const { return elementType; }
|
||||
std::int32_t getRank() const { return rank; }
|
||||
void *getData() const { return data; }
|
||||
|
@ -169,6 +178,7 @@ private:
|
|||
_(None) \
|
||||
_(Bool) \
|
||||
_(Int) \
|
||||
_(Float) \
|
||||
_(Double)
|
||||
|
||||
#define NPCOMP_FORALL_REF_TAGS(_) _(Tensor)
|
||||
|
@ -193,15 +203,23 @@ struct RtValue final {
|
|||
RtValue(std::int64_t i) : tag(Tag::Int) { payload.asInt = i; }
|
||||
RtValue(std::int32_t i) : RtValue(static_cast<int64_t>(i)) {}
|
||||
bool isInt() const { return Tag::Int == tag; }
|
||||
bool toInt() const {
|
||||
int64_t toInt() const {
|
||||
assert(isInt());
|
||||
return payload.asInt;
|
||||
}
|
||||
|
||||
// Float
|
||||
RtValue(float f) : tag(Tag::Float) { payload.asFloat = f; }
|
||||
bool isFloat() const { return Tag::Float == tag; }
|
||||
float toFloat() const {
|
||||
assert(isFloat());
|
||||
return payload.asFloat;
|
||||
}
|
||||
|
||||
// Double
|
||||
RtValue(double d) : tag(Tag::Double) { payload.asDouble = d; }
|
||||
bool isDouble() const { return Tag::Double == tag; }
|
||||
bool toDouble() const {
|
||||
double toDouble() const {
|
||||
assert(isDouble());
|
||||
return payload.asDouble;
|
||||
}
|
||||
|
@ -227,6 +245,11 @@ struct RtValue final {
|
|||
return false;
|
||||
}
|
||||
|
||||
// Scalar
|
||||
bool isScalar() const {
|
||||
return isBool() || isInt() || isFloat() || isDouble();
|
||||
}
|
||||
|
||||
// RtValue (downcast)
|
||||
const RtValue &toRtValue() const { return *this; }
|
||||
RtValue &toRtValue() { return *this; }
|
||||
|
@ -298,6 +321,7 @@ private:
|
|||
union Payload {
|
||||
bool asBool;
|
||||
int64_t asInt;
|
||||
float asFloat;
|
||||
double asDouble;
|
||||
void *asVoidPtr;
|
||||
};
|
||||
|
@ -313,19 +337,72 @@ private:
|
|||
// This is the main entry point that users interact with.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
enum class ArgType : std::uint32_t {
|
||||
kNone = 0,
|
||||
kTensor,
|
||||
kF32,
|
||||
kF64,
|
||||
};
|
||||
StringRef getArgTypeAsStringRef(ArgType type);
|
||||
|
||||
// Maximum rank supported across the ABI boundary
|
||||
constexpr static int kMaxRank = 6;
|
||||
|
||||
struct InputArgInfo {
|
||||
// What type of argument this is
|
||||
ArgType argType;
|
||||
// Certain arg types also have an element type
|
||||
ElementType elementType;
|
||||
std::int32_t rank;
|
||||
std::array<std::int32_t, kMaxRank> extents;
|
||||
};
|
||||
|
||||
struct OutputArgInfo {
|
||||
// What type of argument this is
|
||||
ArgType argType;
|
||||
// Certain arg types also have an element type
|
||||
ElementType elementType;
|
||||
std::int32_t rank;
|
||||
std::array<std::int32_t, kMaxRank> extents;
|
||||
// TODO(brycearden): Add checks for whether output buffers alias to input
|
||||
// buffers and populate field(s) here indicating that case
|
||||
};
|
||||
|
||||
// Maximum input or output arity.
|
||||
constexpr static int kMaxArity = 20;
|
||||
|
||||
// Metadata for a particular function.
|
||||
// TODO: Add arg types.
|
||||
struct FunctionMetadata {
|
||||
std::int32_t numInputs;
|
||||
std::int32_t numOutputs;
|
||||
|
||||
std::array<InputArgInfo, kMaxArity> inputArgInfos;
|
||||
std::array<OutputArgInfo, kMaxArity> outputArgInfos;
|
||||
};
|
||||
|
||||
// Opaque forward declaration of module descriptor type. This is the type
|
||||
// created by the compiler in the module binary.
|
||||
struct ModuleDescriptor;
|
||||
|
||||
// Maximum input or output arity.
|
||||
constexpr static int kMaxArity = 20;
|
||||
// Verifies that the input RtValue arg types match what the user provides
|
||||
// matches the types we expect from the descriptors emitted by the
|
||||
// compiler.
|
||||
//
|
||||
// Returns failure if the input type(s) are not valid
|
||||
LogicalResult checkRtValueArgTypes(const RtValue &value,
|
||||
const InputArgInfo &info);
|
||||
|
||||
// Verifies that the input RtValue shapes matches what the user provides
|
||||
// matches the types we expect from the descriptors emitted by the
|
||||
// compiler.
|
||||
//
|
||||
// Returns failure if the input type(s) are not valid
|
||||
LogicalResult checkRtValueShapes(const RtValue &value,
|
||||
const InputArgInfo &info);
|
||||
|
||||
// Creates an RtValue of the right type from the output metadata
|
||||
// provided by the compiled module
|
||||
RtValue createRtValueFromOutputArgInfo(const OutputArgInfo &info);
|
||||
|
||||
// Low-level invocation API. The number of inputs and outputs should be correct
|
||||
// and match the results of getMetadata.
|
||||
|
|
|
@ -54,6 +54,16 @@ static LogicalResult verify(FuncMetadataOp op) {
|
|||
return op.emitError() << "must agree on number of inputs";
|
||||
if (op.numOutputs() != func.getNumResults())
|
||||
return op.emitError() << "must agree on number of outputs";
|
||||
|
||||
if (op.numInputs() > 0) {
|
||||
if (op.numInputs() != op.inputArgTypes()->size()) {
|
||||
return op.emitError() << "number of inputTypes must match number of inputs";
|
||||
}
|
||||
}
|
||||
if (op.numOutputs() > 0) {
|
||||
if (op.numOutputs() != op.outputArgTypes()->size())
|
||||
return op.emitError() << "number of outputTypes must match number of outputs";
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
||||
|
|
|
@ -12,6 +12,8 @@
|
|||
#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h"
|
||||
#include "npcomp/RefBackend/RefBackend.h"
|
||||
|
||||
#include <sstream>
|
||||
|
||||
using namespace refback;
|
||||
using namespace mlir;
|
||||
using llvm::Error;
|
||||
|
@ -73,6 +75,22 @@ static refbackrt::MutableArrayRef<T> toRefbackrt(llvm::MutableArrayRef<T> a) {
|
|||
return refbackrt::MutableArrayRef<T>(a.data(), a.size());
|
||||
}
|
||||
|
||||
static std::string stringifyShape(refbackrt::ArrayRef<std::int32_t> extents) {
|
||||
static constexpr char *kDynamicDimAsString = "?";
|
||||
std::stringstream ss;
|
||||
ss << "(";
|
||||
for (int i = 0; i < extents.size(); i++) {
|
||||
if (extents[i] < 0)
|
||||
ss << kDynamicDimAsString;
|
||||
else
|
||||
ss << extents[i];
|
||||
if (i != extents.size() - 1)
|
||||
ss << "x";
|
||||
}
|
||||
ss << ")";
|
||||
return ss.str();
|
||||
}
|
||||
|
||||
llvm::Expected<llvm::SmallVector<refbackrt::RtValue, 6>>
|
||||
JITModule::invoke(llvm::StringRef functionName,
|
||||
llvm::ArrayRef<refbackrt::RtValue> inputs) {
|
||||
|
@ -80,12 +98,47 @@ JITModule::invoke(llvm::StringRef functionName,
|
|||
if (refbackrt::failed(refbackrt::getMetadata(
|
||||
descriptor, toRefbackrt(functionName), metadata)))
|
||||
return make_string_error("unknown function: " + Twine(functionName));
|
||||
SmallVector<refbackrt::RtValue, 6> outputs(
|
||||
metadata.numOutputs);
|
||||
SmallVector<refbackrt::RtValue, 6> outputs(metadata.numOutputs);
|
||||
if (metadata.numInputs != static_cast<std::int32_t>(inputs.size()))
|
||||
return make_string_error("invoking '" + Twine(functionName) +
|
||||
"': expected " + Twine(metadata.numInputs) +
|
||||
" inputs");
|
||||
|
||||
// Verify user input types and shapes match what the compiler expects
|
||||
for (int i = 0; i < metadata.numInputs; i++) {
|
||||
auto &input = inputs[i];
|
||||
auto &inputArgInfo = metadata.inputArgInfos[i];
|
||||
if (refbackrt::failed(checkRtValueArgTypes(input, inputArgInfo)))
|
||||
return make_string_error(
|
||||
"invoking '" + Twine(functionName) +
|
||||
"': input argument type mismatch. actual (provided by user): " +
|
||||
Twine(inputs[i].tagKind().str()) + ", expected (from compiler): " +
|
||||
Twine(getArgTypeAsStringRef(inputArgInfo.argType).str()));
|
||||
if (refbackrt::failed(checkRtValueShapes(input, inputArgInfo)))
|
||||
return make_string_error(
|
||||
"invoking '" + Twine(functionName) + "': input shape mismatch (%arg" +
|
||||
Twine(i) + "). " + "actual (provided by user): " +
|
||||
stringifyShape(input.toTensor()->getExtents()) +
|
||||
", expected (from compiler): " +
|
||||
stringifyShape(refbackrt::ArrayRef<int32_t>(
|
||||
inputArgInfo.extents.data(), inputArgInfo.rank)));
|
||||
}
|
||||
|
||||
// Create the correct output RtValue based on FuncMetadata,
|
||||
// which contains the arg types (scalar, Tensor, etc.), element types (only
|
||||
// applicable if not scalar) and shapes (also only applicable if not scalar)
|
||||
//
|
||||
// Currently we have to give each RtValue an output type so that we know
|
||||
// how to pack / unpack the outputs properly across the ABI boundary in
|
||||
// refbackrt::invoke. As a result, we can't just rely on the default
|
||||
// construction of each output argument type (otherwise RtValue will have
|
||||
// Tag::kNone) currently without passing the ArgInfo structs down to the
|
||||
// Runtime level, so we deal with the output type creation here.
|
||||
for (int i = 0; i < metadata.numOutputs; i++) {
|
||||
outputs[i] = std::move(
|
||||
refbackrt::createRtValueFromOutputArgInfo(metadata.outputArgInfos[i]));
|
||||
}
|
||||
|
||||
refbackrt::invoke(
|
||||
descriptor, toRefbackrt(functionName), toRefbackrt(inputs),
|
||||
toRefbackrt(llvm::makeMutableArrayRef(outputs.data(), outputs.size())));
|
||||
|
|
|
@ -34,25 +34,70 @@ using mlir::LLVM::LLVMVoidType;
|
|||
// These correspond to the types in CompilerDataStructures.h
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// MaxRank that the refbackrt ABI lowering is capable of handling
|
||||
// NOTE: This parameter must stay consistent with
|
||||
// `lib/RefBackend/LowerToRefbackrtABI.cpp`
|
||||
static constexpr int kMaxRank = 6;
|
||||
|
||||
static LLVMPointerType getInt8PointerType(MLIRContext *context) {
|
||||
return LLVMPointerType::get(IntegerType::get(context, 8));
|
||||
}
|
||||
|
||||
static LLVMPointerType getInt32PointerType(MLIRContext *context) {
|
||||
return LLVMPointerType::get(IntegerType::get(context, 32));
|
||||
}
|
||||
|
||||
static LLVMStructType getInputDescriptorTy(MLIRContext *context) {
|
||||
return LLVMStructType::getLiteral(
|
||||
context, {
|
||||
// ArgType
|
||||
IntegerType::get(context, 32),
|
||||
// ElementType
|
||||
IntegerType::get(context, 32),
|
||||
// Rank
|
||||
IntegerType::get(context, 32),
|
||||
// Extents
|
||||
LLVMPointerType::get(IntegerType::get(context, 32)),
|
||||
// IsStatic
|
||||
// IntegerType::get(context, 32),
|
||||
});
|
||||
}
|
||||
|
||||
static LLVMStructType getOutputDescriptorTy(MLIRContext *context) {
|
||||
return LLVMStructType::getLiteral(
|
||||
context, {
|
||||
// ArgType
|
||||
IntegerType::get(context, 32),
|
||||
// ElementType
|
||||
IntegerType::get(context, 32),
|
||||
// Rank
|
||||
IntegerType::get(context, 32),
|
||||
// Extents
|
||||
LLVMPointerType::get(IntegerType::get(context, 32)),
|
||||
// IsStatic
|
||||
// IntegerType::get(context, 32),
|
||||
});
|
||||
}
|
||||
|
||||
// Get the LLVM type for refbackrt::FuncDescriptor.
|
||||
static LLVMStructType getFuncDescriptorTy(MLIRContext *context) {
|
||||
return LLVMStructType::getLiteral(context,
|
||||
{
|
||||
// Name length.
|
||||
IntegerType::get(context, 32),
|
||||
// Name chars.
|
||||
getInt8PointerType(context),
|
||||
// Type-erased function pointer.
|
||||
getInt8PointerType(context),
|
||||
// Number of inputs.
|
||||
IntegerType::get(context, 32),
|
||||
// Number of outputs.
|
||||
IntegerType::get(context, 32),
|
||||
});
|
||||
return LLVMStructType::getLiteral(
|
||||
context, {
|
||||
// Name length.
|
||||
IntegerType::get(context, 32),
|
||||
// Name chars.
|
||||
getInt8PointerType(context),
|
||||
// Type-erased function pointer.
|
||||
getInt8PointerType(context),
|
||||
// Number of inputs.
|
||||
IntegerType::get(context, 32),
|
||||
// Number of outputs.
|
||||
IntegerType::get(context, 32),
|
||||
// Argument descriptors
|
||||
LLVMPointerType::get(getInputDescriptorTy(context)),
|
||||
// Result Descriptors
|
||||
LLVMPointerType::get(getOutputDescriptorTy(context)),
|
||||
});
|
||||
}
|
||||
|
||||
// Get the LLVM type for refbackrt::ModuleDescriptor.
|
||||
|
@ -92,8 +137,8 @@ static LLVM::GlobalOp createGlobalString(ModuleOp module, StringAttr msg,
|
|||
// TODO: Deduplicate strings.
|
||||
std::string msgNulTerminated = msg.getValue().str();
|
||||
msgNulTerminated.push_back('\0');
|
||||
auto arrayTy = LLVMArrayType::get(
|
||||
IntegerType::get(module.getContext(), 8), msgNulTerminated.size());
|
||||
auto arrayTy = LLVMArrayType::get(IntegerType::get(module.getContext(), 8),
|
||||
msgNulTerminated.size());
|
||||
OpBuilder::InsertionGuard guard(builder);
|
||||
builder.setInsertionPointToStart(module.getBody());
|
||||
|
||||
|
@ -129,9 +174,9 @@ public:
|
|||
auto globalOp = createGlobalString(op->getParentOfType<ModuleOp>(),
|
||||
op.msgAttr(), rewriter, op.getLoc());
|
||||
auto msgArray = rewriter.create<LLVM::AddressOfOp>(op.getLoc(), globalOp);
|
||||
auto c0 = rewriter.create<LLVM::ConstantOp>(
|
||||
op.getLoc(), IntegerType::get(context, 32),
|
||||
rewriter.getI32IntegerAttr(0));
|
||||
auto c0 = rewriter.create<LLVM::ConstantOp>(op.getLoc(),
|
||||
IntegerType::get(context, 32),
|
||||
rewriter.getI32IntegerAttr(0));
|
||||
auto msg =
|
||||
rewriter.create<LLVM::GEPOp>(op.getLoc(), getInt8PointerType(context),
|
||||
msgArray, ValueRange({c0, c0}));
|
||||
|
@ -181,16 +226,181 @@ createFuncDescriptorArray(ArrayRef<refbackrt::FuncMetadataOp> funcMetadatas,
|
|||
auto llvmI32Ty = IntegerType::get(builder.getContext(), 32);
|
||||
|
||||
DenseMap<StringRef, LLVM::GlobalOp> globalsByName;
|
||||
DenseMap<StringRef, LLVM::GlobalOp> inputDescriptorsByName;
|
||||
DenseMap<StringRef, LLVM::GlobalOp> outputDescriptorsByName;
|
||||
DenseMap<StringRef, LLVM::GlobalOp> inputShapesByName;
|
||||
DenseMap<StringRef, LLVM::GlobalOp> outputShapesByName;
|
||||
for (auto funcMetadata : funcMetadatas) {
|
||||
auto arrayTy =
|
||||
LLVMArrayType::get(IntegerType::get(builder.getContext(), 8),
|
||||
funcMetadata.funcName().size());
|
||||
auto arrayTy = LLVMArrayType::get(IntegerType::get(builder.getContext(), 8),
|
||||
funcMetadata.funcName().size());
|
||||
std::string llvmSymbolName =
|
||||
(Twine("__npcomp_internal_constant_") + funcMetadata.funcName()).str();
|
||||
auto global = builder.create<LLVM::GlobalOp>(
|
||||
loc, arrayTy, /*isConstant=*/true, LLVM::Linkage::Internal,
|
||||
llvmSymbolName, builder.getStringAttr(funcMetadata.funcName()));
|
||||
globalsByName[funcMetadata.funcName()] = global;
|
||||
|
||||
// Create constants for the input / output shapes
|
||||
if (funcMetadata.inputShapes().hasValue()) {
|
||||
auto i32ArrayInputSymbolName =
|
||||
(Twine("__npcomp_internal_constant_input_shapes_") +
|
||||
funcMetadata.funcName())
|
||||
.str();
|
||||
auto inputNumElements = funcMetadata.inputShapes()->getNumElements();
|
||||
auto inputI32ArrayTy =
|
||||
LLVMArrayType::get(builder.getIntegerType(32), inputNumElements);
|
||||
auto inputShapesGlobal = builder.create<LLVM::GlobalOp>(
|
||||
loc, inputI32ArrayTy, /*isConstant=*/true, LLVM::Linkage::Internal,
|
||||
i32ArrayInputSymbolName,
|
||||
/*value=*/funcMetadata.inputShapes().getValue());
|
||||
|
||||
inputShapesByName[funcMetadata.funcName()] = inputShapesGlobal;
|
||||
}
|
||||
|
||||
if (funcMetadata.outputShapes().hasValue()) {
|
||||
auto i32ArrayOutputSymbolName =
|
||||
(Twine("__npcomp_internal_constant_output_shapes_") +
|
||||
funcMetadata.funcName())
|
||||
.str();
|
||||
auto outputNumElements = funcMetadata.outputShapes()->getNumElements();
|
||||
auto outputI32ArrayTy =
|
||||
LLVMArrayType::get(builder.getIntegerType(32), outputNumElements);
|
||||
auto outputShapesGlobal = builder.create<LLVM::GlobalOp>(
|
||||
loc, outputI32ArrayTy, /*isConstant=*/true, LLVM::Linkage::Internal,
|
||||
i32ArrayOutputSymbolName,
|
||||
/*value=*/funcMetadata.outputShapes().getValue());
|
||||
|
||||
outputShapesByName[funcMetadata.funcName()] = outputShapesGlobal;
|
||||
}
|
||||
}
|
||||
|
||||
auto updateDescriptor = [&](Value &descriptor, Value value,
|
||||
std::initializer_list<int32_t> position) {
|
||||
descriptor = builder.create<LLVM::InsertValueOp>(
|
||||
loc, descriptor, value,
|
||||
/*position=*/builder.getI32ArrayAttr(position));
|
||||
};
|
||||
auto updateDescriptorWithI32Attr =
|
||||
[&](Value &descriptor, Attribute attr,
|
||||
std::initializer_list<int32_t> position) {
|
||||
auto constant = builder.create<LLVM::ConstantOp>(loc, llvmI32Ty, attr);
|
||||
updateDescriptor(descriptor, constant, position);
|
||||
};
|
||||
|
||||
// Create global input descriptors
|
||||
for (auto funcMetadata : funcMetadatas) {
|
||||
std::string llvmInputSymbolName =
|
||||
(Twine("__npcomp_input_descriptors_") + funcMetadata.funcName()).str();
|
||||
auto inputDescriptorTy = getInputDescriptorTy(builder.getContext());
|
||||
auto inputDescriptorArrayTy =
|
||||
LLVMArrayType::get(inputDescriptorTy, funcMetadata.numInputs());
|
||||
auto inputDescriptorArrayGlobal = builder.create<LLVM::GlobalOp>(
|
||||
loc, inputDescriptorArrayTy, /*isConstant=*/true,
|
||||
LLVM::Linkage::Internal, llvmInputSymbolName, /*value=*/Attribute());
|
||||
|
||||
OpBuilder::InsertionGuard guard(builder);
|
||||
builder.createBlock(&inputDescriptorArrayGlobal.initializer());
|
||||
|
||||
auto c0 = builder.create<LLVM::ConstantOp>(loc, llvmI32Ty,
|
||||
builder.getI32IntegerAttr(0));
|
||||
|
||||
Value inputDescriptorArray =
|
||||
builder.create<LLVM::UndefOp>(loc, inputDescriptorArrayTy);
|
||||
|
||||
for (int i = 0; i < funcMetadata.numInputs(); i++) {
|
||||
// Arg Type
|
||||
if (!funcMetadata.inputArgTypes().hasValue())
|
||||
funcMetadata.emitError()
|
||||
<< "numInputs > 0 but there are no inputArgTypes?";
|
||||
updateDescriptorWithI32Attr(inputDescriptorArray,
|
||||
funcMetadata.inputArgTypes()->getValue(i),
|
||||
{i, 0});
|
||||
// Element Type
|
||||
updateDescriptorWithI32Attr(inputDescriptorArray,
|
||||
funcMetadata.inputElementTypes()->getValue(i),
|
||||
{i, 1});
|
||||
|
||||
// Rank
|
||||
// auto inputShapesType =
|
||||
// funcMetadata.inputShapes()->getType().dyn_cast<ShapedType>();
|
||||
auto rank = funcMetadata.inputRanks()->getValue(i);
|
||||
updateDescriptorWithI32Attr(inputDescriptorArray, rank, {i, 2});
|
||||
|
||||
// Shape
|
||||
// Each shape array is derived by offseting of kMaxRank * arg index
|
||||
auto extentsArray = builder.create<LLVM::AddressOfOp>(
|
||||
loc, inputShapesByName[funcMetadata.funcName()]);
|
||||
auto cShapeOffset = builder.create<LLVM::ConstantOp>(
|
||||
loc, IntegerType::get(builder.getContext(), 32),
|
||||
builder.getI32IntegerAttr(i * kMaxRank));
|
||||
auto extentsArrayPtr = builder.create<LLVM::GEPOp>(
|
||||
loc, getInt32PointerType(builder.getContext()), extentsArray,
|
||||
ValueRange({c0, cShapeOffset}));
|
||||
updateDescriptor(inputDescriptorArray, extentsArrayPtr, {i, 3});
|
||||
}
|
||||
|
||||
builder.create<LLVM::ReturnOp>(loc, inputDescriptorArray);
|
||||
|
||||
inputDescriptorsByName[funcMetadata.funcName()] =
|
||||
std::move(inputDescriptorArrayGlobal);
|
||||
}
|
||||
|
||||
// Create global output descriptors
|
||||
for (auto funcMetadata : funcMetadatas) {
|
||||
std::string llvmOutputSymbolName =
|
||||
(Twine("__npcomp_output_descriptors_") + funcMetadata.funcName()).str();
|
||||
auto outputDescriptorTy = getOutputDescriptorTy(builder.getContext());
|
||||
auto outputDescriptorArrayTy =
|
||||
LLVMArrayType::get(outputDescriptorTy, funcMetadata.numOutputs());
|
||||
auto outputDescriptorArrayGlobal = builder.create<LLVM::GlobalOp>(
|
||||
loc, outputDescriptorArrayTy, /*isConstant=*/true,
|
||||
LLVM::Linkage::Internal, llvmOutputSymbolName, /*value=*/Attribute());
|
||||
|
||||
OpBuilder::InsertionGuard guard(builder);
|
||||
builder.createBlock(&outputDescriptorArrayGlobal.initializer());
|
||||
|
||||
auto c0 = builder.create<LLVM::ConstantOp>(loc, llvmI32Ty,
|
||||
builder.getI32IntegerAttr(0));
|
||||
|
||||
Value outputDescriptorArray =
|
||||
builder.create<LLVM::UndefOp>(loc, outputDescriptorArrayTy);
|
||||
|
||||
for (int i = 0; i < funcMetadata.numOutputs(); i++) {
|
||||
if (!funcMetadata.outputArgTypes().hasValue())
|
||||
funcMetadata.emitError()
|
||||
<< "numOutputs > 0 but there are no outputArgTypes?";
|
||||
// Arg Type
|
||||
updateDescriptorWithI32Attr(outputDescriptorArray,
|
||||
funcMetadata.outputArgTypes()->getValue(i),
|
||||
{i, 0});
|
||||
// Element Type
|
||||
updateDescriptorWithI32Attr(
|
||||
outputDescriptorArray, funcMetadata.outputElementTypes()->getValue(i),
|
||||
{i, 1});
|
||||
|
||||
// Rank
|
||||
// auto outputShapesType =
|
||||
// funcMetadata.outputShapes()->getType().dyn_cast<ShapedType>();
|
||||
auto rank = funcMetadata.outputRanks()->getValue(i);
|
||||
updateDescriptorWithI32Attr(outputDescriptorArray, rank, {i, 2});
|
||||
|
||||
// Shapes
|
||||
// Offset by kMaxRank * arg index
|
||||
auto extentsArray = builder.create<LLVM::AddressOfOp>(
|
||||
loc, outputShapesByName[funcMetadata.funcName()]);
|
||||
auto cShapeOffset = builder.create<LLVM::ConstantOp>(
|
||||
loc, IntegerType::get(builder.getContext(), 32),
|
||||
builder.getI32IntegerAttr(i * kMaxRank));
|
||||
auto extentsArrayPtr = builder.create<LLVM::GEPOp>(
|
||||
loc, getInt32PointerType(builder.getContext()), extentsArray,
|
||||
ValueRange({c0, cShapeOffset}));
|
||||
updateDescriptor(outputDescriptorArray, extentsArrayPtr, {i, 3});
|
||||
}
|
||||
|
||||
builder.create<LLVM::ReturnOp>(loc, outputDescriptorArray);
|
||||
|
||||
outputDescriptorsByName[funcMetadata.funcName()] =
|
||||
outputDescriptorArrayGlobal;
|
||||
}
|
||||
|
||||
// This must match FuncDescriptor in the runtime.
|
||||
|
@ -201,31 +411,23 @@ createFuncDescriptorArray(ArrayRef<refbackrt::FuncMetadataOp> funcMetadatas,
|
|||
loc, funcDescriptorArrayTy, /*isConstant=*/true, LLVM::Linkage::Internal,
|
||||
"__npcomp_func_descriptors",
|
||||
/*value=*/Attribute());
|
||||
|
||||
OpBuilder::InsertionGuard guard(builder);
|
||||
builder.createBlock(&funcDescriptorArrayGlobal.initializer());
|
||||
|
||||
auto c0 = builder.create<LLVM::ConstantOp>(loc, llvmI32Ty,
|
||||
builder.getI32IntegerAttr(0));
|
||||
// Build the initializer.
|
||||
Value funcDescriptorArray =
|
||||
builder.create<LLVM::UndefOp>(loc, funcDescriptorArrayTy);
|
||||
auto updateDescriptor = [&](Value value,
|
||||
std::initializer_list<int32_t> position) {
|
||||
funcDescriptorArray = builder.create<LLVM::InsertValueOp>(
|
||||
loc, funcDescriptorArray, value,
|
||||
/*position=*/builder.getI32ArrayAttr(position));
|
||||
};
|
||||
auto updateDescriptorWithI32Attr =
|
||||
[&](Attribute attr, std::initializer_list<int32_t> position) {
|
||||
auto constant = builder.create<LLVM::ConstantOp>(loc, llvmI32Ty, attr);
|
||||
updateDescriptor(constant, position);
|
||||
};
|
||||
auto c0 = builder.create<LLVM::ConstantOp>(loc, llvmI32Ty,
|
||||
builder.getI32IntegerAttr(0));
|
||||
|
||||
for (auto funcMetadataAndIndex : llvm::enumerate(funcMetadatas)) {
|
||||
auto funcMetadata = funcMetadataAndIndex.value();
|
||||
int32_t index = funcMetadataAndIndex.index();
|
||||
|
||||
// Name length.
|
||||
updateDescriptorWithI32Attr(
|
||||
funcDescriptorArray,
|
||||
builder.getI32IntegerAttr(funcMetadata.funcName().size()), {index, 0});
|
||||
|
||||
// Name chars.
|
||||
|
@ -234,7 +436,7 @@ createFuncDescriptorArray(ArrayRef<refbackrt::FuncMetadataOp> funcMetadatas,
|
|||
auto funcNamePtr = builder.create<LLVM::GEPOp>(
|
||||
loc, getInt8PointerType(builder.getContext()), funcNameArray,
|
||||
ValueRange({c0, c0}));
|
||||
updateDescriptor(funcNamePtr, {index, 1});
|
||||
updateDescriptor(funcDescriptorArray, funcNamePtr, {index, 1});
|
||||
|
||||
// Function pointer.
|
||||
//
|
||||
|
@ -247,13 +449,31 @@ createFuncDescriptorArray(ArrayRef<refbackrt::FuncMetadataOp> funcMetadatas,
|
|||
loc, getInt8PointerType(builder.getContext()), funcMetadata.funcName());
|
||||
auto typeErasedFuncAddress = builder.create<LLVM::BitcastOp>(
|
||||
loc, getInt8PointerType(builder.getContext()), funcAddress);
|
||||
updateDescriptor(typeErasedFuncAddress, {index, 2});
|
||||
updateDescriptor(funcDescriptorArray, typeErasedFuncAddress, {index, 2});
|
||||
|
||||
// Number of inputs.
|
||||
updateDescriptorWithI32Attr(funcMetadata.numInputsAttr(), {index, 3});
|
||||
updateDescriptorWithI32Attr(funcDescriptorArray,
|
||||
funcMetadata.numInputsAttr(), {index, 3});
|
||||
|
||||
// Number of outputs.
|
||||
updateDescriptorWithI32Attr(funcMetadata.numOutputsAttr(), {index, 4});
|
||||
updateDescriptorWithI32Attr(funcDescriptorArray,
|
||||
funcMetadata.numOutputsAttr(), {index, 4});
|
||||
|
||||
// Input descriptors
|
||||
auto inputDescriptorsArrayAddress = builder.create<LLVM::AddressOfOp>(
|
||||
loc, inputDescriptorsByName[funcMetadata.funcName()]);
|
||||
auto rawInputDescriptorsPtr = builder.create<LLVM::BitcastOp>(
|
||||
loc, LLVMPointerType::get(getInputDescriptorTy(builder.getContext())),
|
||||
inputDescriptorsArrayAddress);
|
||||
updateDescriptor(funcDescriptorArray, rawInputDescriptorsPtr, {index, 5});
|
||||
|
||||
// Output descriptors
|
||||
auto outputDescriptorsArrayAddress = builder.create<LLVM::AddressOfOp>(
|
||||
loc, outputDescriptorsByName[funcMetadata.funcName()]);
|
||||
auto rawOutputDescriptorsPtr = builder.create<LLVM::BitcastOp>(
|
||||
loc, LLVMPointerType::get(getOutputDescriptorTy(builder.getContext())),
|
||||
outputDescriptorsArrayAddress);
|
||||
updateDescriptor(funcDescriptorArray, rawOutputDescriptorsPtr, {index, 6});
|
||||
}
|
||||
|
||||
builder.create<LLVM::ReturnOp>(loc, funcDescriptorArray);
|
||||
|
@ -379,6 +599,16 @@ static Type getUnrankedMemrefDescriptorType(MLIRContext *context) {
|
|||
/*memorySpace=*/0));
|
||||
}
|
||||
|
||||
static Type getDoubleType(MLIRContext *context) {
|
||||
LLVMTypeConverter converter(context);
|
||||
return converter.convertType(FloatType::getF64(context));
|
||||
}
|
||||
|
||||
static Type getFloatType(MLIRContext *context) {
|
||||
LLVMTypeConverter converter(context);
|
||||
return converter.convertType(FloatType::getF32(context));
|
||||
}
|
||||
|
||||
// Writes out the logical results of the wrapper function through the void**
|
||||
// passed on the ABI boundary. Because LLVM (and hence llvm.func)
|
||||
// only supports a single return type (or void/no results), the logic here needs
|
||||
|
@ -393,12 +623,18 @@ static void storeWrapperResults(LLVM::CallOp callToWrapped, Value resultsPtrPtr,
|
|||
return;
|
||||
Value result = callToWrapped.getResult(0);
|
||||
auto ty = result.getType();
|
||||
|
||||
// 1 logical result.
|
||||
if (ty == getUnrankedMemrefDescriptorType(ty.getContext())) {
|
||||
Value addr =
|
||||
getTypedAddressFromVoidStarStar(resultsPtrPtr, 0, ty, builder, loc);
|
||||
builder.create<LLVM::StoreOp>(loc, result, addr);
|
||||
return;
|
||||
} else if (ty == getFloatType(ty.getContext())) {
|
||||
Value addr =
|
||||
getTypedAddressFromVoidStarStar(resultsPtrPtr, 0, ty, builder, loc);
|
||||
builder.create<LLVM::StoreOp>(loc, result, addr);
|
||||
return;
|
||||
}
|
||||
assert(ty.isa<LLVMStructType>() && "must be a multi-result packed struct!");
|
||||
auto structType = ty.cast<LLVMStructType>();
|
||||
|
|
|
@ -21,6 +21,16 @@
|
|||
using namespace mlir;
|
||||
using namespace mlir::NPCOMP;
|
||||
|
||||
// Since input/output shapes are not hyper-rectangular we specify
|
||||
// a maximum rank for each input shape such that shapes are padded
|
||||
// out to kMaxRank at the ABI boundary. That way we can represent
|
||||
// shapes using a traditional DenseElementsAttr.
|
||||
//
|
||||
// NOTE: When changing this parameter, also change the same `kMaxRank`
|
||||
// parameter in `lib/RefBackend/LowerToLLVM.cpp` so that the LLVM lowering
|
||||
// stays consistent.
|
||||
static constexpr int kMaxRank = 6;
|
||||
|
||||
// Get the type used to represent MemRefType `type` on ABI boundaries.
|
||||
// For convenience we do a cast to MemRefType internally.
|
||||
static Type getABIMemrefType(Type type) {
|
||||
|
@ -38,7 +48,99 @@ static bool expressibleWithRefbackrtABI(FunctionType type) {
|
|||
// Currently, only memref types can be exposed at refbackrt ABI boundaries.
|
||||
return llvm::all_of(
|
||||
llvm::concat<const Type>(type.getInputs(), type.getResults()),
|
||||
[](Type t) { return t.isa<MemRefType>(); });
|
||||
[](Type t) {
|
||||
return t.isa<UnrankedMemRefType, MemRefType, FloatType>();
|
||||
});
|
||||
}
|
||||
|
||||
// Returns the integer rerpresentation of the CompilerDataStructures::ABIType
|
||||
// Must stay aligned with CompilerDataStructures::ABIArgType enum
|
||||
static uint32_t getIntReprForABIType(Type type) {
|
||||
if (type.isa<MemRefType>() || type.isa<UnrankedMemRefType>()) {
|
||||
return 1;
|
||||
} else if (auto floatTy = type.dyn_cast<FloatType>()) {
|
||||
switch (floatTy.getWidth()) {
|
||||
case 32:
|
||||
return 2;
|
||||
case 64:
|
||||
return 3;
|
||||
default:
|
||||
assert(false && "Unsupported float bit width");
|
||||
}
|
||||
} else if (auto intTy = type.dyn_cast<IntegerType>()) {
|
||||
}
|
||||
// assert(false && "couldn't get IntReprForABIType");
|
||||
return -1;
|
||||
}
|
||||
|
||||
// Must stay aligned with CompilerDataStructures::ABIElementType enum
|
||||
static uint32_t getIntReprForABIElementType(Type type) {
|
||||
if (auto shapedTy = type.dyn_cast<ShapedType>()) {
|
||||
auto elemTy = shapedTy.getElementType();
|
||||
if (auto floatTy = elemTy.dyn_cast<FloatType>()) {
|
||||
switch (floatTy.getWidth()) {
|
||||
case 32:
|
||||
return 1;
|
||||
default:
|
||||
assert(false && "Unsupported tensor element type");
|
||||
}
|
||||
}
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
static SmallVector<int32_t, kMaxRank>
|
||||
getExtentsForType(Type type, const int32_t maxRank = kMaxRank) {
|
||||
// Extend all shapes out to 4D to make our lives easier at the ABI boundary
|
||||
if (auto shapedType = type.dyn_cast<ShapedType>()) {
|
||||
if (!shapedType.hasRank()) {
|
||||
return {kMaxRank, kMaxRank, kMaxRank, kMaxRank, kMaxRank, kMaxRank};
|
||||
}
|
||||
|
||||
auto shape = shapedType.getShape();
|
||||
auto shapeRank = shapedType.getRank();
|
||||
if (shapeRank <= maxRank) {
|
||||
SmallVector<int32_t, kMaxRank> extendedShape;
|
||||
// Push back all the values of the shape
|
||||
for (auto extentAndIndex : llvm::enumerate(shape)) {
|
||||
auto extent = extentAndIndex.value();
|
||||
auto index = extentAndIndex.index();
|
||||
if (shapedType.isDynamic(index)) {
|
||||
extendedShape.push_back(-1);
|
||||
} else {
|
||||
extendedShape.push_back(extent);
|
||||
}
|
||||
}
|
||||
|
||||
// Pad whatever is left so we have even vectors
|
||||
auto padRank = maxRank - shapeRank;
|
||||
for (int i = 0; i < padRank; i++)
|
||||
extendedShape.push_back(0xDEAD'BEEF);
|
||||
|
||||
return extendedShape;
|
||||
} else {
|
||||
assert(false && "unsupported rank");
|
||||
}
|
||||
}
|
||||
|
||||
// Represent Scalar's as all 1's.
|
||||
return {kMaxRank, kMaxRank, kMaxRank, kMaxRank, kMaxRank, kMaxRank};
|
||||
}
|
||||
|
||||
int32_t getRankForType(Type type) {
|
||||
// Returns a rank of -1 if the tensor is unranked
|
||||
if (auto shapedType = type.dyn_cast<ShapedType>()) {
|
||||
return shapedType.hasRank() ? shapedType.getRank() : -1;
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
uint32_t hasStaticShape(Type type) {
|
||||
if (auto shapedType = type.dyn_cast<ShapedType>()) {
|
||||
return shapedType.hasStaticShape() ? 1 : 0;
|
||||
}
|
||||
// Assume scalars and non-shaped type things are static
|
||||
return 1;
|
||||
}
|
||||
|
||||
static LogicalResult createModuleMetadata(ModuleOp module) {
|
||||
|
@ -59,10 +161,131 @@ static LogicalResult createModuleMetadata(ModuleOp module) {
|
|||
}
|
||||
// TODO: Add richer information here such as expected shapes and element
|
||||
// types.
|
||||
builder.create<refbackrt::FuncMetadataOp>(
|
||||
func.getLoc(), builder.getSymbolRefAttr(func.getName()),
|
||||
builder.getI32IntegerAttr(func.getNumArguments()),
|
||||
builder.getI32IntegerAttr(func.getNumResults()));
|
||||
SmallVector<uint32_t, 6> inputABIArgTypes;
|
||||
SmallVector<uint32_t, 6> inputABIElementTypes;
|
||||
SmallVector<SmallVector<int32_t, kMaxRank>, 6> inputABIShapes;
|
||||
SmallVector<uint32_t, 6> inputABIRanks;
|
||||
// SmallVector<uint32_t, 6> inputIsStatic;
|
||||
for (const auto &inputArgType : func.getBody().front().getArgumentTypes()) {
|
||||
inputABIArgTypes.push_back(getIntReprForABIType(inputArgType));
|
||||
inputABIElementTypes.push_back(getIntReprForABIElementType(inputArgType));
|
||||
inputABIShapes.push_back(
|
||||
getExtentsForType(inputArgType, /*maxRank=*/kMaxRank));
|
||||
inputABIRanks.push_back(getRankForType(inputArgType));
|
||||
// inputIsStatic.push_back(hasStaticShape(inputArgType));
|
||||
}
|
||||
|
||||
SmallVector<uint32_t, 6> outputABIArgTypes;
|
||||
SmallVector<uint32_t, 6> outputABIElementTypes;
|
||||
SmallVector<SmallVector<int32_t, kMaxRank>, 6> outputABIShapes;
|
||||
SmallVector<uint32_t, 6> outputABIRanks;
|
||||
SmallVector<uint32_t, 6> outputIsStatic;
|
||||
for (const auto &outputArgType : func.getCallableResults()) {
|
||||
outputABIArgTypes.push_back(getIntReprForABIType(outputArgType));
|
||||
outputABIElementTypes.push_back(
|
||||
getIntReprForABIElementType(outputArgType));
|
||||
outputABIShapes.push_back(
|
||||
getExtentsForType(outputArgType, /*maxRank=*/kMaxRank));
|
||||
outputABIRanks.push_back(getRankForType(outputArgType));
|
||||
// outputIsStatic.push_back(hasStaticShape(outputArgType));
|
||||
}
|
||||
|
||||
auto i32Type = builder.getIntegerType(32);
|
||||
auto inputABIDataType =
|
||||
RankedTensorType::get(inputABIArgTypes.size(), i32Type);
|
||||
auto inputABIElementType =
|
||||
RankedTensorType::get(inputABIElementTypes.size(), i32Type);
|
||||
auto inputABIShapesType = RankedTensorType::get(
|
||||
llvm::ArrayRef<int64_t>{static_cast<long>(inputABIShapes.size()) *
|
||||
kMaxRank},
|
||||
i32Type);
|
||||
auto inputABIRanksType =
|
||||
RankedTensorType::get(inputABIRanks.size(), i32Type);
|
||||
// auto inputIsStaticType = RankedTensorType::get(inputIsStatic.size(),
|
||||
// i32Type);
|
||||
auto outputABIDataType =
|
||||
RankedTensorType::get(outputABIArgTypes.size(), i32Type);
|
||||
auto outputABIElementType =
|
||||
RankedTensorType::get(outputABIElementTypes.size(), i32Type);
|
||||
auto outputABIShapesType = RankedTensorType::get(
|
||||
llvm::ArrayRef<int64_t>{static_cast<long>(outputABIShapes.size()) *
|
||||
kMaxRank},
|
||||
i32Type);
|
||||
auto outputABIRanksType =
|
||||
RankedTensorType::get(outputABIRanks.size(), i32Type);
|
||||
// auto outputIsStaticType = RankedTensorType::get(outputIsStatic.size(),
|
||||
// i32Type);
|
||||
|
||||
// TODO(brycearden): I'm sure there's a cleaner way to do this
|
||||
auto flattenABIShapes =
|
||||
[](SmallVector<SmallVector<int32_t, kMaxRank>, 6> shapes) {
|
||||
SmallVector<int32_t, 32> ret;
|
||||
for (auto &shape : shapes)
|
||||
for (auto &dim : shape)
|
||||
ret.push_back(dim);
|
||||
return ret;
|
||||
};
|
||||
|
||||
SmallVector<NamedAttribute, 16> namedAttrs;
|
||||
|
||||
// Add attributes that are valid for every func (funcName, numInputs,
|
||||
// numOutputs)
|
||||
namedAttrs.push_back(
|
||||
std::make_pair(Identifier::get("funcName", module.getContext()),
|
||||
builder.getSymbolRefAttr(func.getName())));
|
||||
namedAttrs.push_back(
|
||||
std::make_pair(Identifier::get("numInputs", module.getContext()),
|
||||
builder.getI32IntegerAttr(func.getNumArguments())));
|
||||
namedAttrs.push_back(
|
||||
std::make_pair(Identifier::get("numOutputs", module.getContext()),
|
||||
builder.getI32IntegerAttr(func.getNumResults())));
|
||||
|
||||
if (inputABIArgTypes.size()) {
|
||||
// Only add input information if there are inputs
|
||||
namedAttrs.push_back(std::make_pair(
|
||||
Identifier::get("inputArgTypes", func.getContext()),
|
||||
DenseIntElementsAttr::get(inputABIDataType,
|
||||
llvm::makeArrayRef(inputABIArgTypes))));
|
||||
namedAttrs.push_back(std::make_pair(
|
||||
Identifier::get("inputElementTypes", func.getContext()),
|
||||
DenseIntElementsAttr::get(inputABIElementType,
|
||||
llvm::makeArrayRef(inputABIElementTypes))));
|
||||
namedAttrs.push_back(std::make_pair(
|
||||
Identifier::get("inputRanks", func.getContext()),
|
||||
DenseIntElementsAttr::get(inputABIRanksType,
|
||||
llvm::makeArrayRef(inputABIRanks))));
|
||||
auto inputShapesFlattened = flattenABIShapes(inputABIShapes);
|
||||
namedAttrs.push_back(std::make_pair(
|
||||
Identifier::get("inputShapes", func.getContext()),
|
||||
DenseIntElementsAttr::get(
|
||||
inputABIShapesType,
|
||||
llvm::makeArrayRef(flattenABIShapes(inputABIShapes)))));
|
||||
}
|
||||
|
||||
if (outputABIArgTypes.size()) {
|
||||
// Only add output information if there are outptus
|
||||
namedAttrs.push_back(std::make_pair(
|
||||
Identifier::get("outputArgTypes", func.getContext()),
|
||||
DenseIntElementsAttr::get(outputABIDataType,
|
||||
llvm::makeArrayRef(outputABIArgTypes))));
|
||||
namedAttrs.push_back(std::make_pair(
|
||||
Identifier::get("outputElementTypes", func.getContext()),
|
||||
DenseIntElementsAttr::get(
|
||||
outputABIElementType,
|
||||
llvm::makeArrayRef(outputABIElementTypes))));
|
||||
namedAttrs.push_back(std::make_pair(
|
||||
Identifier::get("outputRanks", func.getContext()),
|
||||
DenseIntElementsAttr::get(outputABIRanksType,
|
||||
llvm::makeArrayRef(outputABIRanks))));
|
||||
namedAttrs.push_back(std::make_pair(
|
||||
Identifier::get("outputShapes", func.getContext()),
|
||||
DenseIntElementsAttr::get(
|
||||
outputABIShapesType,
|
||||
llvm::makeArrayRef(flattenABIShapes(outputABIShapes)))));
|
||||
}
|
||||
|
||||
builder.create<refbackrt::FuncMetadataOp>(func.getLoc(), ArrayRef<Type>{},
|
||||
ArrayRef<Value>{}, namedAttrs);
|
||||
|
||||
if (!expressibleWithRefbackrtABI(func.getType()))
|
||||
return func.emitError() << "func not expressible with refbackrt ABI";
|
||||
|
|
|
@ -23,6 +23,40 @@ namespace refbackrt {
|
|||
// LowerToLLVM.cpp for more details.
|
||||
typedef void ABIFunc(void **, void **);
|
||||
|
||||
enum class ABIArgType : std::uint32_t {
|
||||
kNone = 0,
|
||||
kMemref,
|
||||
kF32,
|
||||
kF64,
|
||||
};
|
||||
|
||||
enum class ABIElementType : std::uint32_t {
|
||||
kNone = 0,
|
||||
kF32,
|
||||
};
|
||||
|
||||
struct InputDescriptor {
|
||||
ABIArgType abiType;
|
||||
ABIElementType elementType;
|
||||
|
||||
std::int32_t rank;
|
||||
std::int32_t* extents;
|
||||
|
||||
// TODO(brycearden): Change to bool at ABI boundary
|
||||
// std::int32_t isStatic;
|
||||
};
|
||||
|
||||
struct OutputDescriptor {
|
||||
ABIArgType abiType;
|
||||
ABIElementType elementType;
|
||||
|
||||
std::int32_t rank;
|
||||
std::int32_t* extents;
|
||||
|
||||
// TODO(brycearden): Change to bool at ABI boundary
|
||||
//std::int32_t isStatic;
|
||||
};
|
||||
|
||||
struct FuncDescriptor {
|
||||
// The length of the function name.
|
||||
std::int32_t nameLen;
|
||||
|
@ -35,9 +69,9 @@ struct FuncDescriptor {
|
|||
std::int32_t numInputs;
|
||||
// The number of outputs of the function.
|
||||
std::int32_t numOutputs;
|
||||
// TODO: Add arg/result descriptors and other metadata.
|
||||
// With those descriptors we can do type and shape checking for each
|
||||
// argument.
|
||||
// TODO: Add shape checking to arg / result descriptor(s)
|
||||
InputDescriptor *inputDescriptors;
|
||||
OutputDescriptor *outputDescriptors;
|
||||
};
|
||||
|
||||
// The top-level entry point of the module metadata emitted by the
|
||||
|
|
|
@ -118,12 +118,38 @@ static std::int32_t totalElements(ArrayRef<std::int32_t> extents) {
|
|||
|
||||
std::int32_t refbackrt::getElementTypeByteSize(ElementType type) {
|
||||
switch (type) {
|
||||
case ElementType::NONE:
|
||||
return 0;
|
||||
case ElementType::F32:
|
||||
return 4;
|
||||
}
|
||||
llvm_unreachable("unsupported dtype");
|
||||
}
|
||||
|
||||
StringRef refbackrt::getElementTypeAsStringRef(ElementType type) {
|
||||
switch (type) {
|
||||
case ElementType::NONE:
|
||||
return "NONE";
|
||||
case ElementType::F32:
|
||||
return "F32";
|
||||
}
|
||||
llvm_unreachable("unsupported element type string");
|
||||
}
|
||||
|
||||
StringRef refbackrt::getArgTypeAsStringRef(ArgType type) {
|
||||
switch (type) {
|
||||
case ArgType::kNone:
|
||||
return "kNone";
|
||||
case ArgType::kTensor:
|
||||
return "kTensor";
|
||||
case ArgType::kF32:
|
||||
return "kF32";
|
||||
case ArgType::kF64:
|
||||
return "kF64";
|
||||
}
|
||||
llvm_unreachable("unsupported arg type string");
|
||||
}
|
||||
|
||||
Ref<Tensor> Tensor::create(ArrayRef<std::int32_t> extents, ElementType type,
|
||||
void *data) {
|
||||
return Ref<Tensor>(createRaw(extents, type, data));
|
||||
|
@ -192,10 +218,7 @@ void refbackrt::invoke(ModuleDescriptor *moduleDescriptor,
|
|||
// Deepcopy the refbackrt::Tensor's into UnrankedMemref's.
|
||||
// TODO: Avoid the deep copy. It makes the later lifetime management code
|
||||
// more complex though (and maybe impossible given the current abstractions).
|
||||
for (int i = 0, e = inputs.size(); i < e; i++) {
|
||||
inputUnrankedMemrefs[i] =
|
||||
convertRefbackrtTensorToUnrankedMemref(inputs[i].toTensor().get());
|
||||
}
|
||||
//
|
||||
// Create a type-erased list of "packed inputs" to pass to the
|
||||
// LLVM/C ABI wrapper function. Each packedInput pointer corresponds to
|
||||
// one LLVM/C ABI argument to the underlying function.
|
||||
|
@ -204,16 +227,30 @@ void refbackrt::invoke(ModuleDescriptor *moduleDescriptor,
|
|||
// "explode" the unranked memref descriptors on the underlying function
|
||||
// into separate arguments for the rank and pointer-to-descriptor.
|
||||
for (int i = 0, e = inputs.size(); i < e; i++) {
|
||||
packedInputs[2 * i] = ToVoidPtr(&inputUnrankedMemrefs[i].rank);
|
||||
packedInputs[2 * i + 1] = ToVoidPtr(&inputUnrankedMemrefs[i].descriptor);
|
||||
auto idx = 2 * i;
|
||||
if (inputs[i].isTensor()) {
|
||||
inputUnrankedMemrefs[i] =
|
||||
convertRefbackrtTensorToUnrankedMemref(inputs[i].toTensor().get());
|
||||
packedInputs[idx] = ToVoidPtr(&inputUnrankedMemrefs[i].rank);
|
||||
packedInputs[idx + 1] = ToVoidPtr(&inputUnrankedMemrefs[i].descriptor);
|
||||
} else if (inputs[i].isScalar()) {
|
||||
packedInputs[idx] = ToVoidPtr(&inputs[i]);
|
||||
} else {
|
||||
assert(false && "unsupported input RtValue type");
|
||||
}
|
||||
}
|
||||
|
||||
// Create a type-erased list of "packed output" to pass to the
|
||||
// LLVM/C ABI wrapper function.
|
||||
//
|
||||
// Due to how StandardToLLVM lowering works, each packedOutput pointer
|
||||
// corresponds to a single UnrankedMemref (not "exploded").
|
||||
for (int i = 0, e = outputs.size(); i < e; i++) {
|
||||
packedOutputs[i] = ToVoidPtr(&outputUnrankedMemrefs[i]);
|
||||
if (outputs[i].isTensor()) {
|
||||
packedOutputs[i] = ToVoidPtr(&outputUnrankedMemrefs[i]);
|
||||
} else if (outputs[i].isScalar()) {
|
||||
packedOutputs[i] = ToVoidPtr(&outputs[i]);
|
||||
}
|
||||
}
|
||||
|
||||
// Actually invoke the function!
|
||||
|
@ -223,11 +260,15 @@ void refbackrt::invoke(ModuleDescriptor *moduleDescriptor,
|
|||
// TODO: Avoid needing to make a deep copy.
|
||||
for (int i = 0, e = outputs.size(); i < e; i++) {
|
||||
// TODO: Have compiler emit the element type in the metadata.
|
||||
auto elementType = ElementType::F32;
|
||||
Tensor *tensor = convertUnrankedMemrefToRefbackrtTensor(
|
||||
outputUnrankedMemrefs[i].rank, outputUnrankedMemrefs[i].descriptor,
|
||||
elementType);
|
||||
outputs[i] = RtValue(Ref<Tensor>(tensor));
|
||||
if (outputs[i].isTensor()) {
|
||||
auto elementType = ElementType::F32;
|
||||
Tensor *tensor = convertUnrankedMemrefToRefbackrtTensor(
|
||||
outputUnrankedMemrefs[i].rank, outputUnrankedMemrefs[i].descriptor,
|
||||
elementType);
|
||||
outputs[i] = RtValue(Ref<Tensor>(tensor));
|
||||
} else if (outputs[i].isFloat()) {
|
||||
outputs[i] = RtValue(*(reinterpret_cast<float *>(packedOutputs[i])));
|
||||
}
|
||||
}
|
||||
|
||||
// Now, we just need to free all the UnrankedMemref's that we created.
|
||||
|
@ -239,24 +280,30 @@ void refbackrt::invoke(ModuleDescriptor *moduleDescriptor,
|
|||
|
||||
// Free the output buffers.
|
||||
for (int i = 0, e = outputs.size(); i < e; i++) {
|
||||
void *allocatedPtr = outputUnrankedMemrefs[i].descriptor->allocatedPtr;
|
||||
// Multiple returned memrefs can point into the same underlying
|
||||
// malloc allocation. Do a linear scan to see if any of the previously
|
||||
// deallocated buffers already freed this pointer.
|
||||
bool bufferNeedsFreeing = true;
|
||||
for (int j = 0; j < i; j++) {
|
||||
if (allocatedPtr == outputUnrankedMemrefs[j].descriptor->allocatedPtr)
|
||||
bufferNeedsFreeing = false;
|
||||
if (outputs[i].isRef()) {
|
||||
void *allocatedPtr = outputUnrankedMemrefs[i].descriptor->allocatedPtr;
|
||||
// Multiple returned memrefs can point into the same underlying
|
||||
// malloc allocation. Do a linear scan to see if any of the previously
|
||||
// deallocated buffers already freed this pointer.
|
||||
bool bufferNeedsFreeing = true;
|
||||
for (int j = 0; j < i; j++) {
|
||||
if (allocatedPtr == outputUnrankedMemrefs[j].descriptor->allocatedPtr)
|
||||
bufferNeedsFreeing = false;
|
||||
}
|
||||
if (!bufferNeedsFreeing)
|
||||
std::free(allocatedPtr);
|
||||
}
|
||||
if (!bufferNeedsFreeing)
|
||||
std::free(allocatedPtr);
|
||||
}
|
||||
|
||||
// Free the input buffers.
|
||||
for (int i = 0, e = inputs.size(); i < e; i++) {
|
||||
if (!inputs[i].isRef())
|
||||
continue;
|
||||
void *allocatedPtr = inputUnrankedMemrefs[i].descriptor->allocatedPtr;
|
||||
bool bufferNeedsFreeing = true;
|
||||
for (int j = 0, je = outputs.size(); j < je; j++) {
|
||||
if (!outputs[j].isRef())
|
||||
continue;
|
||||
if (allocatedPtr == outputUnrankedMemrefs[j].descriptor->allocatedPtr)
|
||||
bufferNeedsFreeing = false;
|
||||
}
|
||||
|
@ -274,6 +321,8 @@ void refbackrt::invoke(ModuleDescriptor *moduleDescriptor,
|
|||
|
||||
// Free the output descriptors.
|
||||
for (int i = 0, e = outputs.size(); i < e; i++) {
|
||||
if (!outputs[i].isRef())
|
||||
continue;
|
||||
// The LLVM lowering guarantees that each returned unranked memref
|
||||
// descriptor is separately malloc'ed, so no need to do anything special
|
||||
// like we had to do for the allocatedPtr's.
|
||||
|
@ -281,10 +330,81 @@ void refbackrt::invoke(ModuleDescriptor *moduleDescriptor,
|
|||
}
|
||||
// Free the input descriptors.
|
||||
for (int i = 0, e = inputs.size(); i < e; i++) {
|
||||
if (!inputs[i].isRef())
|
||||
continue;
|
||||
std::free(inputUnrankedMemrefs[i].descriptor);
|
||||
}
|
||||
}
|
||||
|
||||
static InputArgInfo
|
||||
getExternalInputArgInfo(const refbackrt::InputDescriptor &inputDescriptor) {
|
||||
InputArgInfo ret;
|
||||
|
||||
// Set arg / element types accordingly
|
||||
switch (inputDescriptor.abiType) {
|
||||
case ABIArgType::kNone:
|
||||
ret.argType = ArgType::kNone;
|
||||
ret.elementType = ElementType::NONE;
|
||||
break;
|
||||
case ABIArgType::kMemref:
|
||||
ret.argType = ArgType::kTensor;
|
||||
ret.elementType = ElementType::F32;
|
||||
break;
|
||||
case ABIArgType::kF32:
|
||||
ret.argType = ArgType::kF32;
|
||||
ret.elementType = ElementType::NONE;
|
||||
break;
|
||||
case ABIArgType::kF64:
|
||||
ret.argType = ArgType::kF64;
|
||||
ret.elementType = ElementType::NONE;
|
||||
break;
|
||||
default:
|
||||
assert(false && "need to update external internal map");
|
||||
}
|
||||
|
||||
// Extract shape information
|
||||
ret.rank = inputDescriptor.rank;
|
||||
for (int i = 0; i < inputDescriptor.rank; i++) {
|
||||
ret.extents[i] = inputDescriptor.extents[i];
|
||||
}
|
||||
|
||||
return ret;
|
||||
}
|
||||
|
||||
static OutputArgInfo
|
||||
getExternalOutputArgInfo(const refbackrt::OutputDescriptor &outputDescriptor) {
|
||||
OutputArgInfo ret;
|
||||
|
||||
// Set arg / element types accordingly
|
||||
switch (outputDescriptor.abiType) {
|
||||
case ABIArgType::kNone:
|
||||
ret.argType = ArgType::kNone;
|
||||
ret.elementType = ElementType::NONE;
|
||||
break;
|
||||
case ABIArgType::kMemref:
|
||||
ret.argType = ArgType::kTensor;
|
||||
ret.elementType = ElementType::F32;
|
||||
break;
|
||||
case ABIArgType::kF32:
|
||||
ret.argType = ArgType::kF32;
|
||||
ret.elementType = ElementType::NONE;
|
||||
break;
|
||||
case ABIArgType::kF64:
|
||||
ret.argType = ArgType::kF64;
|
||||
ret.elementType = ElementType::NONE;
|
||||
break;
|
||||
default:
|
||||
assert(false && "need to update external internal map");
|
||||
}
|
||||
|
||||
// Extract shape information
|
||||
ret.rank = outputDescriptor.rank;
|
||||
for (int i = 0; i < outputDescriptor.rank; i++) {
|
||||
ret.extents[i] = outputDescriptor.extents[i];
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
LogicalResult refbackrt::getMetadata(ModuleDescriptor *moduleDescriptor,
|
||||
StringRef functionName,
|
||||
FunctionMetadata &outMetadata) {
|
||||
|
@ -293,5 +413,107 @@ LogicalResult refbackrt::getMetadata(ModuleDescriptor *moduleDescriptor,
|
|||
return failure();
|
||||
outMetadata.numInputs = descriptor->numInputs;
|
||||
outMetadata.numOutputs = descriptor->numOutputs;
|
||||
|
||||
for (int i = 0; i < descriptor->numInputs; i++) {
|
||||
outMetadata.inputArgInfos[i] =
|
||||
getExternalInputArgInfo(descriptor->inputDescriptors[i]);
|
||||
}
|
||||
|
||||
for (int i = 0; i < descriptor->numOutputs; i++) {
|
||||
outMetadata.outputArgInfos[i] =
|
||||
getExternalOutputArgInfo(descriptor->outputDescriptors[i]);
|
||||
}
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult refbackrt::checkRtValueShapes(const RtValue &value,
|
||||
const InputArgInfo &info) {
|
||||
if (value.isTensor()) {
|
||||
auto refTensor = value.toTensor();
|
||||
|
||||
// Don't bother checking shapes for unranked tensors
|
||||
if (info.rank < 0)
|
||||
return success();
|
||||
|
||||
if (refTensor->getRank() != info.rank)
|
||||
return failure();
|
||||
|
||||
auto tensorExtents = refTensor->getExtents();
|
||||
for (int i = 0; i < info.rank; i++) {
|
||||
// If a dimension is dynamic, it is encoded as extent = -1
|
||||
// and we should skip checking over that dimension
|
||||
if (info.extents[i] > 0 && (info.extents[i] != tensorExtents[i]))
|
||||
return failure();
|
||||
}
|
||||
}
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult refbackrt::checkRtValueArgTypes(const RtValue &value,
|
||||
const InputArgInfo &info) {
|
||||
// Generic checks based on argType(s)
|
||||
if ((value.isTensor() && info.argType != ArgType::kTensor) ||
|
||||
(value.isFloat() && info.argType != ArgType::kF32))
|
||||
return failure();
|
||||
|
||||
if (value.isRef()) {
|
||||
// Will need special error checking for ref-counted types
|
||||
// Currently only f32 tensors are supported
|
||||
if (value.isTensor()) {
|
||||
auto refTensor = value.toTensor();
|
||||
if (refTensor->getElementType() != ElementType::F32)
|
||||
return failure();
|
||||
} else {
|
||||
assert(false && "Unsupported input type checking for Ref type");
|
||||
}
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
||||
RtValue refbackrt::createRtValueFromOutputArgInfo(const OutputArgInfo &info) {
|
||||
constexpr int32_t kDynamicConstantShape = 100;
|
||||
switch (info.argType) {
|
||||
case ArgType::kTensor: {
|
||||
// HACK: for dynamic dims the shape will be negative, so for now we are
|
||||
// just going to create a tensor of size kDynamicConstantShape
|
||||
std::array<int32_t, kMaxRank> tensorShape;
|
||||
for (int i = 0; i < info.rank; i++) {
|
||||
tensorShape[i] =
|
||||
info.extents[i] > 0 ? info.extents[i] : kDynamicConstantShape;
|
||||
}
|
||||
refbackrt::ArrayRef<int32_t> shape(tensorShape.data(), info.rank);
|
||||
int numel = 1;
|
||||
for (int i = 0; i < info.rank; i++)
|
||||
numel *= shape[i];
|
||||
|
||||
void *data;
|
||||
switch (info.elementType) {
|
||||
case ElementType::F32: {
|
||||
auto byteSize = numel * sizeof(float);
|
||||
data = static_cast<void *>(aligned_alloc(32, byteSize));
|
||||
memset(data, 0, byteSize);
|
||||
return RtValue(Tensor::create(shape, ElementType::F32, data));
|
||||
break;
|
||||
}
|
||||
default: { assert(false && "unknown output tensor type"); }
|
||||
}
|
||||
|
||||
// The Tensor::create function will malloc and memcpy the data
|
||||
// into the Tensor object, so after we need to free our
|
||||
// temporary data buffer
|
||||
assert(data && "data ptr must exist");
|
||||
auto refTensor = Tensor::create(shape, ElementType::F32, data);
|
||||
free(data);
|
||||
return RtValue(refTensor);
|
||||
}
|
||||
case ArgType::kF32: {
|
||||
return RtValue(-20.0f);
|
||||
}
|
||||
default: {
|
||||
assert(false && "Don't know how to handle this artType");
|
||||
return RtValue();
|
||||
}
|
||||
}
|
||||
}
|
|
@ -3,9 +3,17 @@
|
|||
// CHECK: refbackrt.module_metadata
|
||||
refbackrt.module_metadata {
|
||||
// CHECK: refbackrt.func_metadata
|
||||
refbackrt.func_metadata {funcName = @f, numInputs = 1 : i32, numOutputs = 0 : i32}
|
||||
// TODO(brycearden): Encode unranked memrefs in the ABI
|
||||
refbackrt.func_metadata {
|
||||
funcName = @f,
|
||||
numInputs = 1 : i32,
|
||||
numOutputs = 0 : i32,
|
||||
inputArgTypes = dense<1> : tensor<1xi32>,
|
||||
inputElementTypes = dense<1> : tensor<1xi32>,
|
||||
inputRanks = dense<-1> : tensor<1xi32>,
|
||||
inputShapes = dense<1> : tensor<4xi32>}
|
||||
}
|
||||
|
||||
func @f(%arg0: memref<*xf32>) {
|
||||
func @f(%arg0: tensor<*xf32>) {
|
||||
return
|
||||
}
|
||||
|
|
|
@ -1,165 +0,0 @@
|
|||
// RUN: npcomp-opt -refback-lower-to-llvm -split-input-file <%s | FileCheck %s --dump-input=fail
|
||||
|
||||
// Test input/output arg marshaling.
|
||||
|
||||
// CHECK-LABEL: llvm.func @__refbackrt_wrapper_inputs1results2(
|
||||
// CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr<ptr<i8>>,
|
||||
// CHECK-SAME: %[[VAL_1:.*]]: !llvm.ptr<ptr<i8>>) {
|
||||
// CHECK: %[[VAL_2:.*]] = llvm.mlir.constant(0 : i32) : i32
|
||||
// CHECK: %[[VAL_3:.*]] = llvm.getelementptr %[[VAL_0]]{{\[}}%[[VAL_2]]] : (!llvm.ptr<ptr<i8>>, i32) -> !llvm.ptr<ptr<i8>>
|
||||
// CHECK: %[[VAL_4:.*]] = llvm.load %[[VAL_3]] : !llvm.ptr<ptr<i8>>
|
||||
// CHECK: %[[VAL_5:.*]] = llvm.bitcast %[[VAL_4]] : !llvm.ptr<i8> to !llvm.ptr<i64>
|
||||
// CHECK: %[[VAL_6:.*]] = llvm.load %[[VAL_5]] : !llvm.ptr<i64>
|
||||
// CHECK: %[[VAL_7:.*]] = llvm.mlir.constant(1 : i32) : i32
|
||||
// CHECK: %[[VAL_8:.*]] = llvm.getelementptr %[[VAL_0]]{{\[}}%[[VAL_7]]] : (!llvm.ptr<ptr<i8>>, i32) -> !llvm.ptr<ptr<i8>>
|
||||
// CHECK: %[[VAL_9:.*]] = llvm.load %[[VAL_8]] : !llvm.ptr<ptr<i8>>
|
||||
// CHECK: %[[VAL_10:.*]] = llvm.bitcast %[[VAL_9]] : !llvm.ptr<i8> to !llvm.ptr<ptr<i8>>
|
||||
// CHECK: %[[VAL_11:.*]] = llvm.load %[[VAL_10]] : !llvm.ptr<ptr<i8>>
|
||||
// CHECK: %[[VAL_12:.*]] = llvm.call @inputs1results2(%[[VAL_6]], %[[VAL_11]]) : (i64, !llvm.ptr<i8>) -> !llvm.struct<(struct<(i64, ptr<i8>)>, struct<(i64, ptr<i8>)>)>
|
||||
// CHECK: %[[VAL_13:.*]] = llvm.mlir.constant(0 : i32) : i32
|
||||
// CHECK: %[[VAL_14:.*]] = llvm.getelementptr %[[VAL_1]]{{\[}}%[[VAL_13]]] : (!llvm.ptr<ptr<i8>>, i32) -> !llvm.ptr<ptr<i8>>
|
||||
// CHECK: %[[VAL_15:.*]] = llvm.load %[[VAL_14]] : !llvm.ptr<ptr<i8>>
|
||||
// CHECK: %[[VAL_16:.*]] = llvm.bitcast %[[VAL_15]] : !llvm.ptr<i8> to !llvm.ptr<struct<(i64, ptr<i8>)>>
|
||||
// CHECK: %[[VAL_17:.*]] = llvm.extractvalue %[[VAL_12]][0 : i32] : !llvm.struct<(struct<(i64, ptr<i8>)>, struct<(i64, ptr<i8>)>)>
|
||||
// CHECK: llvm.store %[[VAL_17]], %[[VAL_16]] : !llvm.ptr<struct<(i64, ptr<i8>)>>
|
||||
// CHECK: %[[VAL_18:.*]] = llvm.mlir.constant(1 : i32) : i32
|
||||
// CHECK: %[[VAL_19:.*]] = llvm.getelementptr %[[VAL_1]]{{\[}}%[[VAL_18]]] : (!llvm.ptr<ptr<i8>>, i32) -> !llvm.ptr<ptr<i8>>
|
||||
// CHECK: %[[VAL_20:.*]] = llvm.load %[[VAL_19]] : !llvm.ptr<ptr<i8>>
|
||||
// CHECK: %[[VAL_21:.*]] = llvm.bitcast %[[VAL_20]] : !llvm.ptr<i8> to !llvm.ptr<struct<(i64, ptr<i8>)>>
|
||||
// CHECK: %[[VAL_22:.*]] = llvm.extractvalue %[[VAL_12]][1 : i32] : !llvm.struct<(struct<(i64, ptr<i8>)>, struct<(i64, ptr<i8>)>)>
|
||||
// CHECK: llvm.store %[[VAL_22]], %[[VAL_21]] : !llvm.ptr<struct<(i64, ptr<i8>)>>
|
||||
// CHECK: llvm.return
|
||||
// CHECK: }
|
||||
|
||||
// CHECK-LABEL: llvm.func @__refbackrt_wrapper_inputs1results1(
|
||||
// CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr<ptr<i8>>,
|
||||
// CHECK-SAME: %[[VAL_1:.*]]: !llvm.ptr<ptr<i8>>) {
|
||||
// CHECK: %[[VAL_2:.*]] = llvm.mlir.constant(0 : i32) : i32
|
||||
// CHECK: %[[VAL_3:.*]] = llvm.getelementptr %[[VAL_0]]{{\[}}%[[VAL_2]]] : (!llvm.ptr<ptr<i8>>, i32) -> !llvm.ptr<ptr<i8>>
|
||||
// CHECK: %[[VAL_4:.*]] = llvm.load %[[VAL_3]] : !llvm.ptr<ptr<i8>>
|
||||
// CHECK: %[[VAL_5:.*]] = llvm.bitcast %[[VAL_4]] : !llvm.ptr<i8> to !llvm.ptr<i64>
|
||||
// CHECK: %[[VAL_6:.*]] = llvm.load %[[VAL_5]] : !llvm.ptr<i64>
|
||||
// CHECK: %[[VAL_7:.*]] = llvm.mlir.constant(1 : i32) : i32
|
||||
// CHECK: %[[VAL_8:.*]] = llvm.getelementptr %[[VAL_0]]{{\[}}%[[VAL_7]]] : (!llvm.ptr<ptr<i8>>, i32) -> !llvm.ptr<ptr<i8>>
|
||||
// CHECK: %[[VAL_9:.*]] = llvm.load %[[VAL_8]] : !llvm.ptr<ptr<i8>>
|
||||
// CHECK: %[[VAL_10:.*]] = llvm.bitcast %[[VAL_9]] : !llvm.ptr<i8> to !llvm.ptr<ptr<i8>>
|
||||
// CHECK: %[[VAL_11:.*]] = llvm.load %[[VAL_10]] : !llvm.ptr<ptr<i8>>
|
||||
// CHECK: %[[VAL_12:.*]] = llvm.call @inputs1results1(%[[VAL_6]], %[[VAL_11]]) : (i64, !llvm.ptr<i8>) -> !llvm.struct<(i64, ptr<i8>)>
|
||||
// CHECK: %[[VAL_13:.*]] = llvm.mlir.constant(0 : i32) : i32
|
||||
// CHECK: %[[VAL_14:.*]] = llvm.getelementptr %[[VAL_1]]{{\[}}%[[VAL_13]]] : (!llvm.ptr<ptr<i8>>, i32) -> !llvm.ptr<ptr<i8>>
|
||||
// CHECK: %[[VAL_15:.*]] = llvm.load %[[VAL_14]] : !llvm.ptr<ptr<i8>>
|
||||
// CHECK: %[[VAL_16:.*]] = llvm.bitcast %[[VAL_15]] : !llvm.ptr<i8> to !llvm.ptr<struct<(i64, ptr<i8>)>>
|
||||
// CHECK: llvm.store %[[VAL_12]], %[[VAL_16]] : !llvm.ptr<struct<(i64, ptr<i8>)>>
|
||||
// CHECK: llvm.return
|
||||
// CHECK: }
|
||||
|
||||
/// CHECK-LABEL: llvm.func @__refbackrt_wrapper_inputs1results0(
|
||||
// CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr<ptr<i8>>,
|
||||
// CHECK-SAME: %[[VAL_1:.*]]: !llvm.ptr<ptr<i8>>) {
|
||||
// CHECK: %[[VAL_2:.*]] = llvm.mlir.constant(0 : i32) : i32
|
||||
// CHECK: %[[VAL_3:.*]] = llvm.getelementptr %[[VAL_0]]{{\[}}%[[VAL_2]]] : (!llvm.ptr<ptr<i8>>, i32) -> !llvm.ptr<ptr<i8>>
|
||||
// CHECK: %[[VAL_4:.*]] = llvm.load %[[VAL_3]] : !llvm.ptr<ptr<i8>>
|
||||
// CHECK: %[[VAL_5:.*]] = llvm.bitcast %[[VAL_4]] : !llvm.ptr<i8> to !llvm.ptr<i64>
|
||||
// CHECK: %[[VAL_6:.*]] = llvm.load %[[VAL_5]] : !llvm.ptr<i64>
|
||||
// CHECK: %[[VAL_7:.*]] = llvm.mlir.constant(1 : i32) : i32
|
||||
// CHECK: %[[VAL_8:.*]] = llvm.getelementptr %[[VAL_0]]{{\[}}%[[VAL_7]]] : (!llvm.ptr<ptr<i8>>, i32) -> !llvm.ptr<ptr<i8>>
|
||||
// CHECK: %[[VAL_9:.*]] = llvm.load %[[VAL_8]] : !llvm.ptr<ptr<i8>>
|
||||
// CHECK: %[[VAL_10:.*]] = llvm.bitcast %[[VAL_9]] : !llvm.ptr<i8> to !llvm.ptr<ptr<i8>>
|
||||
// CHECK: %[[VAL_11:.*]] = llvm.load %[[VAL_10]] : !llvm.ptr<ptr<i8>>
|
||||
// CHECK: llvm.call @inputs1results0(%[[VAL_6]], %[[VAL_11]]) : (i64, !llvm.ptr<i8>) -> ()
|
||||
// CHECK: llvm.return
|
||||
// CHECK: }
|
||||
// CHECK: llvm.func @__npcomp_compiler_rt_abort_if(i1, !llvm.ptr<i8>)
|
||||
// CHECK: llvm.mlir.global internal constant @__npcomp_internal_constant_inputs1results0("inputs1results0")
|
||||
// CHECK: llvm.mlir.global internal constant @__npcomp_internal_constant_inputs1results1("inputs1results1")
|
||||
// CHECK: llvm.mlir.global internal constant @__npcomp_internal_constant_inputs1results2("inputs1results2")
|
||||
|
||||
// CHECK-LABEL: llvm.mlir.global internal constant @__npcomp_func_descriptors() : !llvm.array<3 x struct<(i32, ptr<i8>, ptr<i8>, i32, i32)>> {
|
||||
// CHECK: %[[VAL_0:.*]] = llvm.mlir.undef : !llvm.array<3 x struct<(i32, ptr<i8>, ptr<i8>, i32, i32)>>
|
||||
// CHECK: %[[VAL_1:.*]] = llvm.mlir.constant(0 : i32) : i32
|
||||
// CHECK: %[[VAL_2:.*]] = llvm.mlir.constant(15 : i32) : i32
|
||||
// CHECK: %[[VAL_3:.*]] = llvm.insertvalue %[[VAL_2]], %[[VAL_0]][0 : i32, 0 : i32] : !llvm.array<3 x struct<(i32, ptr<i8>, ptr<i8>, i32, i32)>>
|
||||
// CHECK: %[[VAL_4:.*]] = llvm.mlir.addressof @__npcomp_internal_constant_inputs1results0 : !llvm.ptr<array<15 x i8>>
|
||||
// CHECK: %[[VAL_5:.*]] = llvm.getelementptr %[[VAL_4]]{{\[}}%[[VAL_1]], %[[VAL_1]]] : (!llvm.ptr<array<15 x i8>>, i32, i32) -> !llvm.ptr<i8>
|
||||
// CHECK: %[[VAL_6:.*]] = llvm.insertvalue %[[VAL_5]], %[[VAL_3]][0 : i32, 1 : i32] : !llvm.array<3 x struct<(i32, ptr<i8>, ptr<i8>, i32, i32)>>
|
||||
// CHECK: %[[VAL_7:.*]] = llvm.mlir.addressof @__refbackrt_wrapper_inputs1results0 : !llvm.ptr<func<void (ptr<ptr<i8>>, ptr<ptr<i8>>)>>
|
||||
// CHECK: %[[VAL_8:.*]] = llvm.bitcast %[[VAL_7]] : !llvm.ptr<func<void (ptr<ptr<i8>>, ptr<ptr<i8>>)>> to !llvm.ptr<i8>
|
||||
// CHECK: %[[VAL_9:.*]] = llvm.insertvalue %[[VAL_8]], %[[VAL_6]][0 : i32, 2 : i32] : !llvm.array<3 x struct<(i32, ptr<i8>, ptr<i8>, i32, i32)>>
|
||||
// CHECK: %[[VAL_10:.*]] = llvm.mlir.constant(1 : i32) : i32
|
||||
// CHECK: %[[VAL_11:.*]] = llvm.insertvalue %[[VAL_10]], %[[VAL_9]][0 : i32, 3 : i32] : !llvm.array<3 x struct<(i32, ptr<i8>, ptr<i8>, i32, i32)>>
|
||||
// CHECK: %[[VAL_12:.*]] = llvm.mlir.constant(0 : i32) : i32
|
||||
// CHECK: %[[VAL_13:.*]] = llvm.insertvalue %[[VAL_12]], %[[VAL_11]][0 : i32, 4 : i32] : !llvm.array<3 x struct<(i32, ptr<i8>, ptr<i8>, i32, i32)>>
|
||||
// CHECK: %[[VAL_14:.*]] = llvm.mlir.constant(15 : i32) : i32
|
||||
// CHECK: %[[VAL_15:.*]] = llvm.insertvalue %[[VAL_14]], %[[VAL_13]][1 : i32, 0 : i32] : !llvm.array<3 x struct<(i32, ptr<i8>, ptr<i8>, i32, i32)>>
|
||||
// CHECK: %[[VAL_16:.*]] = llvm.mlir.addressof @__npcomp_internal_constant_inputs1results1 : !llvm.ptr<array<15 x i8>>
|
||||
// CHECK: %[[VAL_17:.*]] = llvm.getelementptr %[[VAL_16]]{{\[}}%[[VAL_1]], %[[VAL_1]]] : (!llvm.ptr<array<15 x i8>>, i32, i32) -> !llvm.ptr<i8>
|
||||
// CHECK: %[[VAL_18:.*]] = llvm.insertvalue %[[VAL_17]], %[[VAL_15]][1 : i32, 1 : i32] : !llvm.array<3 x struct<(i32, ptr<i8>, ptr<i8>, i32, i32)>>
|
||||
// CHECK: %[[VAL_19:.*]] = llvm.mlir.addressof @__refbackrt_wrapper_inputs1results1 : !llvm.ptr<func<void (ptr<ptr<i8>>, ptr<ptr<i8>>)>>
|
||||
// CHECK: %[[VAL_20:.*]] = llvm.bitcast %[[VAL_19]] : !llvm.ptr<func<void (ptr<ptr<i8>>, ptr<ptr<i8>>)>> to !llvm.ptr<i8>
|
||||
// CHECK: %[[VAL_21:.*]] = llvm.insertvalue %[[VAL_20]], %[[VAL_18]][1 : i32, 2 : i32] : !llvm.array<3 x struct<(i32, ptr<i8>, ptr<i8>, i32, i32)>>
|
||||
// CHECK: %[[VAL_22:.*]] = llvm.mlir.constant(1 : i32) : i32
|
||||
// CHECK: %[[VAL_23:.*]] = llvm.insertvalue %[[VAL_22]], %[[VAL_21]][1 : i32, 3 : i32] : !llvm.array<3 x struct<(i32, ptr<i8>, ptr<i8>, i32, i32)>>
|
||||
// CHECK: %[[VAL_24:.*]] = llvm.mlir.constant(1 : i32) : i32
|
||||
// CHECK: %[[VAL_25:.*]] = llvm.insertvalue %[[VAL_24]], %[[VAL_23]][1 : i32, 4 : i32] : !llvm.array<3 x struct<(i32, ptr<i8>, ptr<i8>, i32, i32)>>
|
||||
// CHECK: %[[VAL_26:.*]] = llvm.mlir.constant(15 : i32) : i32
|
||||
// CHECK: %[[VAL_27:.*]] = llvm.insertvalue %[[VAL_26]], %[[VAL_25]][2 : i32, 0 : i32] : !llvm.array<3 x struct<(i32, ptr<i8>, ptr<i8>, i32, i32)>>
|
||||
// CHECK: %[[VAL_28:.*]] = llvm.mlir.addressof @__npcomp_internal_constant_inputs1results2 : !llvm.ptr<array<15 x i8>>
|
||||
// CHECK: %[[VAL_29:.*]] = llvm.getelementptr %[[VAL_28]]{{\[}}%[[VAL_1]], %[[VAL_1]]] : (!llvm.ptr<array<15 x i8>>, i32, i32) -> !llvm.ptr<i8>
|
||||
// CHECK: %[[VAL_30:.*]] = llvm.insertvalue %[[VAL_29]], %[[VAL_27]][2 : i32, 1 : i32] : !llvm.array<3 x struct<(i32, ptr<i8>, ptr<i8>, i32, i32)>>
|
||||
// CHECK: %[[VAL_31:.*]] = llvm.mlir.addressof @__refbackrt_wrapper_inputs1results2 : !llvm.ptr<func<void (ptr<ptr<i8>>, ptr<ptr<i8>>)>>
|
||||
// CHECK: %[[VAL_32:.*]] = llvm.bitcast %[[VAL_31]] : !llvm.ptr<func<void (ptr<ptr<i8>>, ptr<ptr<i8>>)>> to !llvm.ptr<i8>
|
||||
// CHECK: %[[VAL_33:.*]] = llvm.insertvalue %[[VAL_32]], %[[VAL_30]][2 : i32, 2 : i32] : !llvm.array<3 x struct<(i32, ptr<i8>, ptr<i8>, i32, i32)>>
|
||||
// CHECK: %[[VAL_34:.*]] = llvm.mlir.constant(1 : i32) : i32
|
||||
// CHECK: %[[VAL_35:.*]] = llvm.insertvalue %[[VAL_34]], %[[VAL_33]][2 : i32, 3 : i32] : !llvm.array<3 x struct<(i32, ptr<i8>, ptr<i8>, i32, i32)>>
|
||||
// CHECK: %[[VAL_36:.*]] = llvm.mlir.constant(2 : i32) : i32
|
||||
// CHECK: %[[VAL_37:.*]] = llvm.insertvalue %[[VAL_36]], %[[VAL_35]][2 : i32, 4 : i32] : !llvm.array<3 x struct<(i32, ptr<i8>, ptr<i8>, i32, i32)>>
|
||||
// CHECK: llvm.return %[[VAL_37]] : !llvm.array<3 x struct<(i32, ptr<i8>, ptr<i8>, i32, i32)>>
|
||||
// CHECK: }
|
||||
|
||||
// CHECK-LABEL: llvm.mlir.global external constant @_mlir___npcomp_module_descriptor() : !llvm.struct<(i32, ptr<struct<(i32, ptr<i8>, ptr<i8>, i32, i32)>>)> {
|
||||
// CHECK: %[[VAL_0:.*]] = llvm.mlir.undef : !llvm.struct<(i32, ptr<struct<(i32, ptr<i8>, ptr<i8>, i32, i32)>>)>
|
||||
// CHECK: %[[VAL_1:.*]] = llvm.mlir.constant(3 : i32) : i32
|
||||
// CHECK: %[[VAL_2:.*]] = llvm.insertvalue %[[VAL_1]], %[[VAL_0]][0 : i32] : !llvm.struct<(i32, ptr<struct<(i32, ptr<i8>, ptr<i8>, i32, i32)>>)>
|
||||
// CHECK: %[[VAL_3:.*]] = llvm.mlir.addressof @__npcomp_func_descriptors : !llvm.ptr<array<3 x struct<(i32, ptr<i8>, ptr<i8>, i32, i32)>>>
|
||||
// CHECK: %[[VAL_4:.*]] = llvm.bitcast %[[VAL_3]] : !llvm.ptr<array<3 x struct<(i32, ptr<i8>, ptr<i8>, i32, i32)>>> to !llvm.ptr<struct<(i32, ptr<i8>, ptr<i8>, i32, i32)>>
|
||||
// CHECK: %[[VAL_5:.*]] = llvm.insertvalue %[[VAL_4]], %[[VAL_2]][1 : i32] : !llvm.struct<(i32, ptr<struct<(i32, ptr<i8>, ptr<i8>, i32, i32)>>)>
|
||||
// CHECK: llvm.return %[[VAL_5]] : !llvm.struct<(i32, ptr<struct<(i32, ptr<i8>, ptr<i8>, i32, i32)>>)>
|
||||
// CHECK: }
|
||||
|
||||
refbackrt.module_metadata {
|
||||
refbackrt.func_metadata {funcName = @inputs1results0, numInputs = 1 : i32, numOutputs = 0 : i32}
|
||||
refbackrt.func_metadata {funcName = @inputs1results1, numInputs = 1 : i32, numOutputs = 1 : i32}
|
||||
refbackrt.func_metadata {funcName = @inputs1results2, numInputs = 1 : i32, numOutputs = 2 : i32}
|
||||
}
|
||||
|
||||
func @inputs1results0(%arg0: memref<*xf32>) {
|
||||
return
|
||||
}
|
||||
|
||||
func @inputs1results1(%arg0: memref<*xf32>) -> memref<*xf32> {
|
||||
return %arg0 : memref<*xf32>
|
||||
}
|
||||
|
||||
func @inputs1results2(%arg0: memref<*xf32>) -> (memref<*xf32>, memref<*xf32>) {
|
||||
return %arg0, %arg0 : memref<*xf32>, memref<*xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Test emission of compiler runtime functions.
|
||||
|
||||
// CHECK: llvm.mlir.global internal constant @[[STRSYM:.*]]("msg\00")
|
||||
// CHECK: llvm.func @__npcomp_compiler_rt_abort_if(i1, !llvm.ptr<i8>)
|
||||
|
||||
// CHECK-LABEL: llvm.func @calls_abort_if(
|
||||
// CHECK-SAME: %[[VAL_0:.*]]: i1) {
|
||||
// CHECK: %[[VAL_0:.*]] = llvm.mlir.addressof @[[STRSYM]] : !llvm.ptr<array<4 x i8>>
|
||||
// CHECK: %[[VAL_1:.*]] = llvm.mlir.constant(0 : i32) : i32
|
||||
// CHECK: %[[VAL_2:.*]] = llvm.getelementptr %[[VAL_0]]{{\[}}%[[VAL_1]], %[[VAL_1]]] : (!llvm.ptr<array<4 x i8>>, i32, i32) -> !llvm.ptr<i8>
|
||||
// CHECK: llvm.call @__npcomp_compiler_rt_abort_if(%[[VAL_3:.*]], %[[VAL_2]]) : (i1, !llvm.ptr<i8>) -> ()
|
||||
// CHECK: llvm.return
|
||||
|
||||
func @calls_abort_if(%arg0: i1) {
|
||||
refbackrt.abort_if %arg0, "msg"
|
||||
return
|
||||
}
|
|
@ -3,8 +3,14 @@
|
|||
// Test module metadata.
|
||||
|
||||
// CHECK: refbackrt.module_metadata
|
||||
// CHECK-NEXT: refbackrt.func_metadata {funcName = @f_2inputs_0outputs, numInputs = 2 : i32, numOutputs = 0 : i32}
|
||||
// CHECK-NEXT: refbackrt.func_metadata {funcName = @f_1input_2outputs, numInputs = 1 : i32, numOutputs = 2 : i32}
|
||||
// CHECK-NEXT: refbackrt.func_metadata
|
||||
// CHECK-SAME: funcName = @f_2inputs_0outputs
|
||||
// CHECK-SAME: numInputs = 2
|
||||
// CHECK-SAME: numOutputs = 0
|
||||
// CHECK-NEXT: refbackrt.func_metadata
|
||||
// CHECK-SAME: funcName = @f_1input_2outputs
|
||||
// CHECK-SAME: numInputs = 1
|
||||
// CHECK-SAME: numOutputs = 2
|
||||
|
||||
// This function only exists to test its metadata above.
|
||||
func @f_2inputs_0outputs(%arg0: memref<?xf32>, %arg1: memref<?xf32>) {
|
||||
|
|
|
@ -8,5 +8,4 @@
|
|||
func @basic(%arg0: tensor<?xf32>) -> tensor<?xf32> {
|
||||
%0 = tcf.add %arg0, %arg0 : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
|
||||
return %0 : tensor<?xf32>
|
||||
}
|
||||
|
||||
}
|
|
@ -0,0 +1,26 @@
|
|||
// RUN: not npcomp-run-mlir %s \
|
||||
// RUN: -invoke invalid_input_shape \
|
||||
// RUN: -arg-value="dense<1.0> : tensor<2x2x2x2xf32>" \
|
||||
// RUN: -shared-libs=%npcomp_runtime_shlib 2>&1 \
|
||||
// RUN: | FileCheck %s -check-prefix=ARG0-INVALID
|
||||
|
||||
// RUN: not npcomp-run-mlir %s \
|
||||
// RUN: -invoke invalid_input_shape_arg1 \
|
||||
// RUN: -arg-value="dense<1.0> : tensor<1x2x5xf32>" \
|
||||
// RUN: -arg-value="dense<1.0> : tensor<1x2x10xf32>" \
|
||||
// RUN: -shared-libs=%npcomp_runtime_shlib 2>&1 \
|
||||
// RUN: | FileCheck %s -check-prefix=ARG1-INVALID
|
||||
|
||||
// ARG0-INVALID: invoking 'invalid_input_shape': input shape mismatch (%arg0).
|
||||
// ARG0-INVALID-SAME: actual (provided by user): (2x2x2x2)
|
||||
// ARG0-INVALID-SAME: expected (from compiler): (1x2x3x4)
|
||||
func @invalid_input_shape(%arg0: tensor<1x2x3x4xf32>) -> tensor<1x2x3x4xf32> {
|
||||
return %arg0: tensor<1x2x3x4xf32>
|
||||
}
|
||||
|
||||
// ARG1-INVALID: invoking 'invalid_input_shape_arg1': input shape mismatch (%arg1)
|
||||
// ARG1-INVALID-SAME: actual (provided by user): (1x2x10)
|
||||
// ARG1-INVALID-SAME: expected (from compiler): (1x4x?)
|
||||
func @invalid_input_shape_arg1(%arg0: tensor<1x2x?xf32>, %arg1: tensor<1x4x?xf32>) {
|
||||
return
|
||||
}
|
|
@ -0,0 +1,13 @@
|
|||
// RUN: not npcomp-run-mlir %s \
|
||||
// RUN: -invoke expects_one_tensor \
|
||||
// RUN: -arg-value="1.0 : f32" \
|
||||
// RUN: -shared-libs=%npcomp_runtime_shlib 2>&1 \
|
||||
// RUN: | FileCheck %s
|
||||
|
||||
// CHECK: invoking 'expects_one_tensor': input argument type mismatch.
|
||||
// CHECK-SAME: actual (provided by user): Float
|
||||
// CHECK-SAME: expected (from compiler): kTensor
|
||||
func @expects_one_tensor(%arg0: tensor<?xf32>) -> tensor<?xf32> {
|
||||
%0 = tcf.add %arg0, %arg0 : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
|
||||
return %0 : tensor<?xf32>
|
||||
}
|
|
@ -2,10 +2,21 @@
|
|||
// RUN: -invoke scalar \
|
||||
// RUN: -arg-value="dense<1.0> : tensor<f32>" \
|
||||
// RUN: -shared-libs=%npcomp_runtime_shlib 2>&1 \
|
||||
// RUN: | FileCheck %s
|
||||
// RUN: | FileCheck %s --check-prefix=SCALAR
|
||||
|
||||
// CHECK: output #0: dense<2.000000e+00> : tensor<f32>
|
||||
// RUN: npcomp-run-mlir %s \
|
||||
// RUN: -invoke scalar_arg \
|
||||
// RUN: -arg-value="2.5 : f32" \
|
||||
// RUN: -shared-libs=%npcomp_runtime_shlib 2>&1 \
|
||||
// RUN: | FileCheck %s --check-prefix=SCALAR_ARG
|
||||
|
||||
// SCALAR: output #0: dense<2.000000e+00> : tensor<f32>
|
||||
func @scalar(%arg0: tensor<f32>) -> tensor<f32> {
|
||||
%0 = tcf.add %arg0, %arg0 : (tensor<f32>, tensor<f32>) -> tensor<f32>
|
||||
return %0 : tensor<f32>
|
||||
}
|
||||
|
||||
// SCALAR_ARG: output #0: 2.500000e+00 : f32
|
||||
func @scalar_arg(%arg0: f32) -> f32 {
|
||||
return %arg0 : f32
|
||||
}
|
|
@ -58,6 +58,14 @@ convertAttrToTensor(Attribute attr) {
|
|||
return make_string_error("unhandled argument");
|
||||
}
|
||||
|
||||
static Expected<float> convertAttrToFloat(Attribute attr) {
|
||||
auto type = attr.getType().dyn_cast<FloatType>();
|
||||
if (!type)
|
||||
return make_string_error("converting an argument to float that is not a FloatType");
|
||||
auto floatAttr = attr.dyn_cast<FloatAttr>();
|
||||
return floatAttr.getValue().convertToFloat();
|
||||
}
|
||||
|
||||
static Expected<SmallVector<refbackrt::RtValue, 6>>
|
||||
createInputs(ArrayRef<StringRef> argValues) {
|
||||
MLIRContext context;
|
||||
|
@ -66,12 +74,22 @@ createInputs(ArrayRef<StringRef> argValues) {
|
|||
auto attr = parseAttribute(argValue, &context);
|
||||
if (!attr)
|
||||
return make_string_error(Twine("could not parse arg value: ") + argValue);
|
||||
// TODO(brycearden): Handle multiple input types
|
||||
auto expectedTensor = convertAttrToTensor(attr);
|
||||
if (!expectedTensor)
|
||||
return expectedTensor.takeError();
|
||||
ret.push_back(std::move(*expectedTensor));
|
||||
|
||||
auto attrType = attr.getType();
|
||||
|
||||
if (attrType.isa<RankedTensorType>()) {
|
||||
auto expectedTensor = convertAttrToTensor(attr);
|
||||
if (!expectedTensor)
|
||||
return expectedTensor.takeError();
|
||||
ret.push_back(std::move(*expectedTensor));
|
||||
} else if (attrType.isa<FloatType>()) {
|
||||
auto expectedFloat = convertAttrToFloat(attr);
|
||||
if (!expectedFloat)
|
||||
return expectedFloat.takeError();
|
||||
ret.push_back(refbackrt::RtValue(*expectedFloat));
|
||||
}
|
||||
}
|
||||
|
||||
return ret;
|
||||
}
|
||||
|
||||
|
@ -92,34 +110,40 @@ static RankedTensorType getCorrespondingMLIRTensorType(refbackrt::Tensor &tensor
|
|||
return RankedTensorType::get(extents, elementType);
|
||||
}
|
||||
|
||||
static Attribute convertToMLIRAttribute(refbackrt::Tensor &tensor,
|
||||
static Attribute convertToMLIRAttribute(const refbackrt::RtValue &value,
|
||||
Builder &builder) {
|
||||
RankedTensorType type = getCorrespondingMLIRTensorType(tensor, builder);
|
||||
switch (tensor.getElementType()) {
|
||||
case refbackrt::ElementType::F32: {
|
||||
SmallVector<float, 100> values;
|
||||
auto *basePtr = tensor.getData<float>();
|
||||
for (int i = 0, e = type.getNumElements(); i < e; i++)
|
||||
values.push_back(basePtr[i]);
|
||||
return DenseFPElementsAttr::get(type, values);
|
||||
}
|
||||
if (value.isTensor()) {
|
||||
auto& tensor = *(value.toTensor());
|
||||
RankedTensorType type = getCorrespondingMLIRTensorType(tensor, builder);
|
||||
switch (tensor.getElementType()) {
|
||||
case refbackrt::ElementType::F32: {
|
||||
SmallVector<float, 100> values;
|
||||
auto *basePtr = tensor.getData<float>();
|
||||
for (int i = 0, e = type.getNumElements(); i < e; i++)
|
||||
values.push_back(basePtr[i]);
|
||||
return DenseFPElementsAttr::get(type, values);
|
||||
}
|
||||
}
|
||||
} else if (value.isFloat()) {
|
||||
return builder.getF32FloatAttr(value.toFloat());
|
||||
} else {
|
||||
assert(false && "could not convert value to mlir attribute");
|
||||
}
|
||||
llvm_unreachable("unsupported dtype");
|
||||
}
|
||||
|
||||
static void printOutput(refbackrt::Tensor &tensor, llvm::raw_ostream &os) {
|
||||
static void printOutput(const refbackrt::RtValue &value, llvm::raw_ostream &os) {
|
||||
MLIRContext context;
|
||||
Builder builder(&context);
|
||||
auto attr = convertToMLIRAttribute(tensor, builder);
|
||||
auto attr = convertToMLIRAttribute(value, builder);
|
||||
attr.print(os);
|
||||
}
|
||||
|
||||
static void printOutputs(ArrayRef<refbackrt::RtValue> outputs,
|
||||
llvm::raw_ostream &os) {
|
||||
for (auto output : llvm::enumerate(outputs)) {
|
||||
assert(output.value().isTensor() && "only tensor outputs are supported.");
|
||||
os << "output #" << output.index() << ": ";
|
||||
printOutput(*output.value().toTensor().get(), os);
|
||||
printOutput(output.value(), os);
|
||||
os << "\n";
|
||||
}
|
||||
}
|
||||
|
@ -150,9 +174,11 @@ Error compileAndRun(std::string mlirFile, mlir::MLIRContext &context,
|
|||
auto expectedInputs = createInputs(argValues);
|
||||
if (!expectedInputs)
|
||||
return expectedInputs.takeError();
|
||||
|
||||
auto expectedOutputs = jitModule->invoke(invokeFunction, *expectedInputs);
|
||||
if (!expectedOutputs)
|
||||
return expectedOutputs.takeError();
|
||||
|
||||
auto outputs = std::move(*expectedOutputs);
|
||||
printOutputs(outputs, llvm::outs());
|
||||
llvm::outs() << "SUCCESS\n";
|
||||
|
|
Loading…
Reference in New Issue