mirror of https://github.com/llvm/torch-mlir
[refbackrt] Scalar arg support
* Adds f32 scalar argument support across the ABI boundary. * Adds support for passing input type / shape information across the ABI boundary. * Adds support for parsing / creating input FloatAttrs in `npcomp-run-mlir`. pull/197/head
parent
703428eff4
commit
4591884d06
|
@ -68,12 +68,47 @@ def Refbackrt_FuncMetadataOp
|
||||||
let description = [{
|
let description = [{
|
||||||
Runtime metadata for a single func.
|
Runtime metadata for a single func.
|
||||||
|
|
||||||
TODO: Augment this with information for type/shape checking of arguments.
|
Contains type / shape information for arguments as described below:
|
||||||
|
|
||||||
|
* ArgType(s):
|
||||||
|
Integer value from `CompilerDataStructures.h` for each argument
|
||||||
|
indicating what type it is (e.g. Float, Int, Tensor, Dict, etc.)
|
||||||
|
* ElementType(s):
|
||||||
|
Certain input ArgType's also have an element type (e.g. Tensor<float>,
|
||||||
|
List<int>, etc.)
|
||||||
|
TODO(brycearden): Support nested types (e.g. List<Tensor<float>>)
|
||||||
|
* Rank(s):
|
||||||
|
Integer value indicating the rank for each argument.
|
||||||
|
* Shape(s):
|
||||||
|
Flattened hyper-rectangular representation of the shapes for each argument.
|
||||||
|
Since each shape's size varies based on the Rank, we pad out the shapes
|
||||||
|
to size kMaxRank to make ABI lowering easier. See `LowerToRefbackrtABI.cpp`
|
||||||
|
for details.
|
||||||
|
|
||||||
|
Shapes Example:
|
||||||
|
constexpr int kMaxRank = 6;
|
||||||
|
// func @f(%arg0: f32, %arg1: tensor<5xf32>) would result in...
|
||||||
|
inputShapes = dense<...> : tensor<12xi32>
|
||||||
|
// 2 shapes with 6 elements each (12 total), where only the first
|
||||||
|
// `rank` values in each shape are valid.
|
||||||
|
//
|
||||||
|
// The LowerToLLVM pass can then update the struct(s) by grabbing a pointer at
|
||||||
|
// %shape_ptr = %base + (kMaxRank * argIndex)
|
||||||
}];
|
}];
|
||||||
let arguments = (ins
|
let arguments = (ins
|
||||||
FlatSymbolRefAttr:$funcName,
|
FlatSymbolRefAttr:$funcName,
|
||||||
I32Attr:$numInputs,
|
I32Attr:$numInputs,
|
||||||
I32Attr:$numOutputs
|
I32Attr:$numOutputs,
|
||||||
|
OptionalAttr<I32ElementsAttr>:$inputArgTypes,
|
||||||
|
OptionalAttr<I32ElementsAttr>:$inputElementTypes,
|
||||||
|
OptionalAttr<I32ElementsAttr>:$inputRanks,
|
||||||
|
OptionalAttr<I32ElementsAttr>:$inputShapes,
|
||||||
|
// I32ElementsAttr:$inputIsStatic,
|
||||||
|
OptionalAttr<I32ElementsAttr>:$outputArgTypes,
|
||||||
|
OptionalAttr<I32ElementsAttr>:$outputElementTypes,
|
||||||
|
OptionalAttr<I32ElementsAttr>:$outputRanks,
|
||||||
|
OptionalAttr<I32ElementsAttr>:$outputShapes
|
||||||
|
//I32ElementsAttr:$outputIsStatic
|
||||||
);
|
);
|
||||||
let results = (outs);
|
let results = (outs);
|
||||||
let assemblyFormat = "attr-dict";
|
let assemblyFormat = "attr-dict";
|
||||||
|
|
|
@ -31,6 +31,8 @@ public:
|
||||||
return std::memcmp(ptr, other.ptr, length) == 0;
|
return std::memcmp(ptr, other.ptr, length) == 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Returns the raw `ptr` member.
// NOTE(review): the length is not returned alongside it, and nothing visible
// here establishes that the data is NUL-terminated -- confirm before treating
// the result as a C string.
const char* str() { return ptr; }
|
||||||
|
|
||||||
private:
|
private:
|
||||||
const char *ptr;
|
const char *ptr;
|
||||||
std::size_t length;
|
std::size_t length;
|
||||||
|
|
|
@ -22,6 +22,7 @@
|
||||||
#define NPCOMP_RUNTIME_USERAPI_H
|
#define NPCOMP_RUNTIME_USERAPI_H
|
||||||
|
|
||||||
#include "npcomp/RefBackend/Runtime/Support.h"
|
#include "npcomp/RefBackend/Runtime/Support.h"
|
||||||
|
#include <array>
|
||||||
#include <atomic>
|
#include <atomic>
|
||||||
#include <cstdlib>
|
#include <cstdlib>
|
||||||
|
|
||||||
|
@ -105,9 +106,11 @@ private:
|
||||||
|
|
||||||
// The available data types.
|
// The available data types.
|
||||||
enum class ElementType : std::int32_t {
|
enum class ElementType : std::int32_t {
|
||||||
|
NONE,
|
||||||
F32,
|
F32,
|
||||||
};
|
};
|
||||||
std::int32_t getElementTypeByteSize(ElementType type);
|
std::int32_t getElementTypeByteSize(ElementType type);
|
||||||
|
StringRef getElementTypeAsStringRef(ElementType type);
|
||||||
|
|
||||||
// Representation of a tensor.
|
// Representation of a tensor.
|
||||||
class Tensor : public RefTarget {
|
class Tensor : public RefTarget {
|
||||||
|
@ -124,6 +127,12 @@ public:
|
||||||
static Tensor *createRaw(ArrayRef<std::int32_t> extents,
|
static Tensor *createRaw(ArrayRef<std::int32_t> extents,
|
||||||
ElementType elementType, void *data);
|
ElementType elementType, void *data);
|
||||||
|
|
||||||
|
static Ref<Tensor> create(ArrayRef<std::int64_t> extents,
|
||||||
|
ElementType elementType, void *data);
|
||||||
|
// Same as `create`, but returns a raw pointer.
|
||||||
|
static Tensor *createRaw(ArrayRef<std::int64_t> extents,
|
||||||
|
ElementType elementType, void *data);
|
||||||
|
|
||||||
ElementType getElementType() const { return elementType; }
|
ElementType getElementType() const { return elementType; }
|
||||||
std::int32_t getRank() const { return rank; }
|
std::int32_t getRank() const { return rank; }
|
||||||
void *getData() const { return data; }
|
void *getData() const { return data; }
|
||||||
|
@ -169,6 +178,7 @@ private:
|
||||||
_(None) \
|
_(None) \
|
||||||
_(Bool) \
|
_(Bool) \
|
||||||
_(Int) \
|
_(Int) \
|
||||||
|
_(Float) \
|
||||||
_(Double)
|
_(Double)
|
||||||
|
|
||||||
#define NPCOMP_FORALL_REF_TAGS(_) _(Tensor)
|
#define NPCOMP_FORALL_REF_TAGS(_) _(Tensor)
|
||||||
|
@ -193,15 +203,23 @@ struct RtValue final {
|
||||||
RtValue(std::int64_t i) : tag(Tag::Int) { payload.asInt = i; }
|
RtValue(std::int64_t i) : tag(Tag::Int) { payload.asInt = i; }
|
||||||
RtValue(std::int32_t i) : RtValue(static_cast<int64_t>(i)) {}
|
RtValue(std::int32_t i) : RtValue(static_cast<int64_t>(i)) {}
|
||||||
bool isInt() const { return Tag::Int == tag; }
|
bool isInt() const { return Tag::Int == tag; }
|
||||||
bool toInt() const {
|
int64_t toInt() const {
|
||||||
assert(isInt());
|
assert(isInt());
|
||||||
return payload.asInt;
|
return payload.asInt;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Float
|
||||||
|
RtValue(float f) : tag(Tag::Float) { payload.asFloat = f; }
|
||||||
|
bool isFloat() const { return Tag::Float == tag; }
|
||||||
|
float toFloat() const {
|
||||||
|
assert(isFloat());
|
||||||
|
return payload.asFloat;
|
||||||
|
}
|
||||||
|
|
||||||
// Double
|
// Double
|
||||||
RtValue(double d) : tag(Tag::Double) { payload.asDouble = d; }
|
RtValue(double d) : tag(Tag::Double) { payload.asDouble = d; }
|
||||||
bool isDouble() const { return Tag::Double == tag; }
|
bool isDouble() const { return Tag::Double == tag; }
|
||||||
bool toDouble() const {
|
double toDouble() const {
|
||||||
assert(isDouble());
|
assert(isDouble());
|
||||||
return payload.asDouble;
|
return payload.asDouble;
|
||||||
}
|
}
|
||||||
|
@ -227,6 +245,11 @@ struct RtValue final {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Scalar
|
||||||
|
bool isScalar() const {
|
||||||
|
return isBool() || isInt() || isFloat() || isDouble();
|
||||||
|
}
|
||||||
|
|
||||||
// RtValue (downcast)
|
// RtValue (downcast)
|
||||||
const RtValue &toRtValue() const { return *this; }
|
const RtValue &toRtValue() const { return *this; }
|
||||||
RtValue &toRtValue() { return *this; }
|
RtValue &toRtValue() { return *this; }
|
||||||
|
@ -298,6 +321,7 @@ private:
|
||||||
union Payload {
|
union Payload {
|
||||||
bool asBool;
|
bool asBool;
|
||||||
int64_t asInt;
|
int64_t asInt;
|
||||||
|
float asFloat;
|
||||||
double asDouble;
|
double asDouble;
|
||||||
void *asVoidPtr;
|
void *asVoidPtr;
|
||||||
};
|
};
|
||||||
|
@ -313,19 +337,72 @@ private:
|
||||||
// This is the main entry point that users interact with.
|
// This is the main entry point that users interact with.
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
// Kind of value described by an argument/result descriptor.
// The numeric values are spelled out explicitly; they are the same values the
// implicit numbering (starting from kNone = 0) would assign.
enum class ArgType : std::uint32_t {
  kNone = 0,
  kTensor = 1,
  kF32 = 2,
  kF64 = 3,
};
|
||||||
|
StringRef getArgTypeAsStringRef(ArgType type);
|
||||||
|
|
||||||
|
// Maximum rank supported across the ABI boundary
|
||||||
|
constexpr static int kMaxRank = 6;
|
||||||
|
|
||||||
|
struct InputArgInfo {
|
||||||
|
// What type of argument this is
|
||||||
|
ArgType argType;
|
||||||
|
// Certain arg types also have an element type
|
||||||
|
ElementType elementType;
|
||||||
|
std::int32_t rank;
|
||||||
|
std::array<std::int32_t, kMaxRank> extents;
|
||||||
|
};
|
||||||
|
|
||||||
|
struct OutputArgInfo {
|
||||||
|
// What type of argument this is
|
||||||
|
ArgType argType;
|
||||||
|
// Certain arg types also have an element type
|
||||||
|
ElementType elementType;
|
||||||
|
std::int32_t rank;
|
||||||
|
std::array<std::int32_t, kMaxRank> extents;
|
||||||
|
// TODO(brycearden): Add checks for whether output buffers alias to input
|
||||||
|
// buffers and populate field(s) here indicating that case
|
||||||
|
};
|
||||||
|
|
||||||
|
// Maximum input or output arity.
|
||||||
|
constexpr static int kMaxArity = 20;
|
||||||
|
|
||||||
// Metadata for a particular function.
|
// Metadata for a particular function.
|
||||||
// TODO: Add arg types.
|
|
||||||
struct FunctionMetadata {
|
struct FunctionMetadata {
|
||||||
std::int32_t numInputs;
|
std::int32_t numInputs;
|
||||||
std::int32_t numOutputs;
|
std::int32_t numOutputs;
|
||||||
|
|
||||||
|
std::array<InputArgInfo, kMaxArity> inputArgInfos;
|
||||||
|
std::array<OutputArgInfo, kMaxArity> outputArgInfos;
|
||||||
};
|
};
|
||||||
|
|
||||||
// Opaque forward declaration of module descriptor type. This is the type
|
// Opaque forward declaration of module descriptor type. This is the type
|
||||||
// created by the compiler in the module binary.
|
// created by the compiler in the module binary.
|
||||||
struct ModuleDescriptor;
|
struct ModuleDescriptor;
|
||||||
|
|
||||||
// Maximum input or output arity.
|
// Verifies that the arg type of the RtValue the user provides
|
||||||
constexpr static int kMaxArity = 20;
|
// matches the types we expect from the descriptors emitted by the
|
||||||
|
// compiler.
|
||||||
|
//
|
||||||
|
// Returns failure if the input type(s) are not valid
|
||||||
|
LogicalResult checkRtValueArgTypes(const RtValue &value,
|
||||||
|
const InputArgInfo &info);
|
||||||
|
|
||||||
|
// Verifies that the shape of the RtValue the user provides
|
||||||
|
// matches the shape we expect from the descriptors emitted by the
|
||||||
|
// compiler.
|
||||||
|
//
|
||||||
|
// Returns failure if the input shape(s) are not valid
|
||||||
|
LogicalResult checkRtValueShapes(const RtValue &value,
|
||||||
|
const InputArgInfo &info);
|
||||||
|
|
||||||
|
// Creates an RtValue of the right type from the output metadata
|
||||||
|
// provided by the compiled module
|
||||||
|
RtValue createRtValueFromOutputArgInfo(const OutputArgInfo &info);
|
||||||
|
|
||||||
// Low-level invocation API. The number of inputs and outputs should be correct
|
// Low-level invocation API. The number of inputs and outputs should be correct
|
||||||
// and match the results of getMetadata.
|
// and match the results of getMetadata.
|
||||||
|
|
|
@ -54,6 +54,16 @@ static LogicalResult verify(FuncMetadataOp op) {
|
||||||
return op.emitError() << "must agree on number of inputs";
|
return op.emitError() << "must agree on number of inputs";
|
||||||
if (op.numOutputs() != func.getNumResults())
|
if (op.numOutputs() != func.getNumResults())
|
||||||
return op.emitError() << "must agree on number of outputs";
|
return op.emitError() << "must agree on number of outputs";
|
||||||
|
|
||||||
|
if (op.numInputs() > 0) {
|
||||||
|
if (op.numInputs() != op.inputArgTypes()->size()) {
|
||||||
|
return op.emitError() << "number of inputTypes must match number of inputs";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (op.numOutputs() > 0) {
|
||||||
|
if (op.numOutputs() != op.outputArgTypes()->size())
|
||||||
|
return op.emitError() << "number of outputTypes must match number of outputs";
|
||||||
|
}
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -12,6 +12,8 @@
|
||||||
#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h"
|
#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h"
|
||||||
#include "npcomp/RefBackend/RefBackend.h"
|
#include "npcomp/RefBackend/RefBackend.h"
|
||||||
|
|
||||||
|
#include <sstream>
|
||||||
|
|
||||||
using namespace refback;
|
using namespace refback;
|
||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
using llvm::Error;
|
using llvm::Error;
|
||||||
|
@ -73,6 +75,22 @@ static refbackrt::MutableArrayRef<T> toRefbackrt(llvm::MutableArrayRef<T> a) {
|
||||||
return refbackrt::MutableArrayRef<T>(a.data(), a.size());
|
return refbackrt::MutableArrayRef<T>(a.data(), a.size());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Renders `extents` as a shape string of the form "(d0xd1x...)", for example
// "(2x?x4)". An empty list renders as "()". Negative extents are printed as
// "?" (the dynamic-dimension marker).
static std::string stringifyShape(refbackrt::ArrayRef<std::int32_t> extents) {
  // Must be `const char *`: a string literal cannot initialise a non-const
  // `char *` in standard C++.
  static constexpr const char *kDynamicDimAsString = "?";
  std::stringstream ss;
  ss << "(";
  const std::size_t numExtents = extents.size();
  for (std::size_t i = 0; i < numExtents; ++i) {
    // Write the separator before every extent except the first, so no
    // `size() - 1` arithmetic on an unsigned value is needed.
    if (i != 0)
      ss << "x";
    if (extents[i] < 0)
      ss << kDynamicDimAsString;
    else
      ss << extents[i];
  }
  ss << ")";
  return ss.str();
}
|
||||||
|
|
||||||
llvm::Expected<llvm::SmallVector<refbackrt::RtValue, 6>>
|
llvm::Expected<llvm::SmallVector<refbackrt::RtValue, 6>>
|
||||||
JITModule::invoke(llvm::StringRef functionName,
|
JITModule::invoke(llvm::StringRef functionName,
|
||||||
llvm::ArrayRef<refbackrt::RtValue> inputs) {
|
llvm::ArrayRef<refbackrt::RtValue> inputs) {
|
||||||
|
@ -80,12 +98,47 @@ JITModule::invoke(llvm::StringRef functionName,
|
||||||
if (refbackrt::failed(refbackrt::getMetadata(
|
if (refbackrt::failed(refbackrt::getMetadata(
|
||||||
descriptor, toRefbackrt(functionName), metadata)))
|
descriptor, toRefbackrt(functionName), metadata)))
|
||||||
return make_string_error("unknown function: " + Twine(functionName));
|
return make_string_error("unknown function: " + Twine(functionName));
|
||||||
SmallVector<refbackrt::RtValue, 6> outputs(
|
SmallVector<refbackrt::RtValue, 6> outputs(metadata.numOutputs);
|
||||||
metadata.numOutputs);
|
|
||||||
if (metadata.numInputs != static_cast<std::int32_t>(inputs.size()))
|
if (metadata.numInputs != static_cast<std::int32_t>(inputs.size()))
|
||||||
return make_string_error("invoking '" + Twine(functionName) +
|
return make_string_error("invoking '" + Twine(functionName) +
|
||||||
"': expected " + Twine(metadata.numInputs) +
|
"': expected " + Twine(metadata.numInputs) +
|
||||||
" inputs");
|
" inputs");
|
||||||
|
|
||||||
|
// Verify user input types and shapes match what the compiler expects
|
||||||
|
for (int i = 0; i < metadata.numInputs; i++) {
|
||||||
|
auto &input = inputs[i];
|
||||||
|
auto &inputArgInfo = metadata.inputArgInfos[i];
|
||||||
|
if (refbackrt::failed(checkRtValueArgTypes(input, inputArgInfo)))
|
||||||
|
return make_string_error(
|
||||||
|
"invoking '" + Twine(functionName) +
|
||||||
|
"': input argument type mismatch. actual (provided by user): " +
|
||||||
|
Twine(inputs[i].tagKind().str()) + ", expected (from compiler): " +
|
||||||
|
Twine(getArgTypeAsStringRef(inputArgInfo.argType).str()));
|
||||||
|
if (refbackrt::failed(checkRtValueShapes(input, inputArgInfo)))
|
||||||
|
return make_string_error(
|
||||||
|
"invoking '" + Twine(functionName) + "': input shape mismatch (%arg" +
|
||||||
|
Twine(i) + "). " + "actual (provided by user): " +
|
||||||
|
stringifyShape(input.toTensor()->getExtents()) +
|
||||||
|
", expected (from compiler): " +
|
||||||
|
stringifyShape(refbackrt::ArrayRef<int32_t>(
|
||||||
|
inputArgInfo.extents.data(), inputArgInfo.rank)));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create the correct output RtValue based on FuncMetadata,
|
||||||
|
// which contains the arg types (scalar, Tensor, etc.), element types (only
|
||||||
|
// applicable if not scalar) and shapes (also only applicable if not scalar)
|
||||||
|
//
|
||||||
|
// Currently we have to give each RtValue an output type so that we know
|
||||||
|
// how to pack / unpack the outputs properly across the ABI boundary in
|
||||||
|
// refbackrt::invoke. As a result, we can't just rely on the default
|
||||||
|
// construction of each output argument type (otherwise RtValue will have
|
||||||
|
// Tag::kNone) currently without passing the ArgInfo structs down to the
|
||||||
|
// Runtime level, so we deal with the output type creation here.
|
||||||
|
for (int i = 0; i < metadata.numOutputs; i++) {
|
||||||
|
outputs[i] = std::move(
|
||||||
|
refbackrt::createRtValueFromOutputArgInfo(metadata.outputArgInfos[i]));
|
||||||
|
}
|
||||||
|
|
||||||
refbackrt::invoke(
|
refbackrt::invoke(
|
||||||
descriptor, toRefbackrt(functionName), toRefbackrt(inputs),
|
descriptor, toRefbackrt(functionName), toRefbackrt(inputs),
|
||||||
toRefbackrt(llvm::makeMutableArrayRef(outputs.data(), outputs.size())));
|
toRefbackrt(llvm::makeMutableArrayRef(outputs.data(), outputs.size())));
|
||||||
|
|
|
@ -34,25 +34,70 @@ using mlir::LLVM::LLVMVoidType;
|
||||||
// These correspond to the types in CompilerDataStructures.h
|
// These correspond to the types in CompilerDataStructures.h
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
// MaxRank that the refbackrt ABI lowering is capable of handling
|
||||||
|
// NOTE: This parameter must stay consistent with
|
||||||
|
// `lib/RefBackend/LowerToRefbackrtABI.cpp`
|
||||||
|
static constexpr int kMaxRank = 6;
|
||||||
|
|
||||||
static LLVMPointerType getInt8PointerType(MLIRContext *context) {
|
static LLVMPointerType getInt8PointerType(MLIRContext *context) {
|
||||||
return LLVMPointerType::get(IntegerType::get(context, 8));
|
return LLVMPointerType::get(IntegerType::get(context, 8));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Returns the LLVM dialect pointer type whose pointee is a 32-bit integer.
static LLVMPointerType getInt32PointerType(MLIRContext *context) {
  auto i32Ty = IntegerType::get(context, 32);
  return LLVMPointerType::get(i32Ty);
}
|
||||||
|
|
||||||
|
// Builds the literal LLVM struct type that describes one input argument.
// Field order: ArgType (i32), ElementType (i32), Rank (i32), Extents (pointer
// to i32).
static LLVMStructType getInputDescriptorTy(MLIRContext *context) {
  auto i32Ty = IntegerType::get(context, 32);
  auto extentsPtrTy = LLVMPointerType::get(IntegerType::get(context, 32));
  return LLVMStructType::getLiteral(context,
                                    {
                                        /*ArgType=*/i32Ty,
                                        /*ElementType=*/i32Ty,
                                        /*Rank=*/i32Ty,
                                        /*Extents=*/extentsPtrTy,
                                        // IsStatic
                                        // IntegerType::get(context, 32),
                                    });
}
|
||||||
|
|
||||||
|
// Builds the literal LLVM struct type that describes one result value.
// Field order: ArgType (i32), ElementType (i32), Rank (i32), Extents (pointer
// to i32).
static LLVMStructType getOutputDescriptorTy(MLIRContext *context) {
  auto i32Ty = IntegerType::get(context, 32);
  auto extentsPtrTy = LLVMPointerType::get(IntegerType::get(context, 32));
  return LLVMStructType::getLiteral(context,
                                    {
                                        /*ArgType=*/i32Ty,
                                        /*ElementType=*/i32Ty,
                                        /*Rank=*/i32Ty,
                                        /*Extents=*/extentsPtrTy,
                                        // IsStatic
                                        // IntegerType::get(context, 32),
                                    });
}
|
||||||
|
|
||||||
// Get the LLVM type for refbackrt::FuncDescriptor.
|
// Get the LLVM type for refbackrt::FuncDescriptor.
|
||||||
static LLVMStructType getFuncDescriptorTy(MLIRContext *context) {
|
static LLVMStructType getFuncDescriptorTy(MLIRContext *context) {
|
||||||
return LLVMStructType::getLiteral(context,
|
return LLVMStructType::getLiteral(
|
||||||
{
|
context, {
|
||||||
// Name length.
|
// Name length.
|
||||||
IntegerType::get(context, 32),
|
IntegerType::get(context, 32),
|
||||||
// Name chars.
|
// Name chars.
|
||||||
getInt8PointerType(context),
|
getInt8PointerType(context),
|
||||||
// Type-erased function pointer.
|
// Type-erased function pointer.
|
||||||
getInt8PointerType(context),
|
getInt8PointerType(context),
|
||||||
// Number of inputs.
|
// Number of inputs.
|
||||||
IntegerType::get(context, 32),
|
IntegerType::get(context, 32),
|
||||||
// Number of outputs.
|
// Number of outputs.
|
||||||
IntegerType::get(context, 32),
|
IntegerType::get(context, 32),
|
||||||
});
|
// Argument descriptors
|
||||||
|
LLVMPointerType::get(getInputDescriptorTy(context)),
|
||||||
|
// Result Descriptors
|
||||||
|
LLVMPointerType::get(getOutputDescriptorTy(context)),
|
||||||
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get the LLVM type for refbackrt::ModuleDescriptor.
|
// Get the LLVM type for refbackrt::ModuleDescriptor.
|
||||||
|
@ -92,8 +137,8 @@ static LLVM::GlobalOp createGlobalString(ModuleOp module, StringAttr msg,
|
||||||
// TODO: Deduplicate strings.
|
// TODO: Deduplicate strings.
|
||||||
std::string msgNulTerminated = msg.getValue().str();
|
std::string msgNulTerminated = msg.getValue().str();
|
||||||
msgNulTerminated.push_back('\0');
|
msgNulTerminated.push_back('\0');
|
||||||
auto arrayTy = LLVMArrayType::get(
|
auto arrayTy = LLVMArrayType::get(IntegerType::get(module.getContext(), 8),
|
||||||
IntegerType::get(module.getContext(), 8), msgNulTerminated.size());
|
msgNulTerminated.size());
|
||||||
OpBuilder::InsertionGuard guard(builder);
|
OpBuilder::InsertionGuard guard(builder);
|
||||||
builder.setInsertionPointToStart(module.getBody());
|
builder.setInsertionPointToStart(module.getBody());
|
||||||
|
|
||||||
|
@ -129,9 +174,9 @@ public:
|
||||||
auto globalOp = createGlobalString(op->getParentOfType<ModuleOp>(),
|
auto globalOp = createGlobalString(op->getParentOfType<ModuleOp>(),
|
||||||
op.msgAttr(), rewriter, op.getLoc());
|
op.msgAttr(), rewriter, op.getLoc());
|
||||||
auto msgArray = rewriter.create<LLVM::AddressOfOp>(op.getLoc(), globalOp);
|
auto msgArray = rewriter.create<LLVM::AddressOfOp>(op.getLoc(), globalOp);
|
||||||
auto c0 = rewriter.create<LLVM::ConstantOp>(
|
auto c0 = rewriter.create<LLVM::ConstantOp>(op.getLoc(),
|
||||||
op.getLoc(), IntegerType::get(context, 32),
|
IntegerType::get(context, 32),
|
||||||
rewriter.getI32IntegerAttr(0));
|
rewriter.getI32IntegerAttr(0));
|
||||||
auto msg =
|
auto msg =
|
||||||
rewriter.create<LLVM::GEPOp>(op.getLoc(), getInt8PointerType(context),
|
rewriter.create<LLVM::GEPOp>(op.getLoc(), getInt8PointerType(context),
|
||||||
msgArray, ValueRange({c0, c0}));
|
msgArray, ValueRange({c0, c0}));
|
||||||
|
@ -181,16 +226,181 @@ createFuncDescriptorArray(ArrayRef<refbackrt::FuncMetadataOp> funcMetadatas,
|
||||||
auto llvmI32Ty = IntegerType::get(builder.getContext(), 32);
|
auto llvmI32Ty = IntegerType::get(builder.getContext(), 32);
|
||||||
|
|
||||||
DenseMap<StringRef, LLVM::GlobalOp> globalsByName;
|
DenseMap<StringRef, LLVM::GlobalOp> globalsByName;
|
||||||
|
DenseMap<StringRef, LLVM::GlobalOp> inputDescriptorsByName;
|
||||||
|
DenseMap<StringRef, LLVM::GlobalOp> outputDescriptorsByName;
|
||||||
|
DenseMap<StringRef, LLVM::GlobalOp> inputShapesByName;
|
||||||
|
DenseMap<StringRef, LLVM::GlobalOp> outputShapesByName;
|
||||||
for (auto funcMetadata : funcMetadatas) {
|
for (auto funcMetadata : funcMetadatas) {
|
||||||
auto arrayTy =
|
auto arrayTy = LLVMArrayType::get(IntegerType::get(builder.getContext(), 8),
|
||||||
LLVMArrayType::get(IntegerType::get(builder.getContext(), 8),
|
funcMetadata.funcName().size());
|
||||||
funcMetadata.funcName().size());
|
|
||||||
std::string llvmSymbolName =
|
std::string llvmSymbolName =
|
||||||
(Twine("__npcomp_internal_constant_") + funcMetadata.funcName()).str();
|
(Twine("__npcomp_internal_constant_") + funcMetadata.funcName()).str();
|
||||||
auto global = builder.create<LLVM::GlobalOp>(
|
auto global = builder.create<LLVM::GlobalOp>(
|
||||||
loc, arrayTy, /*isConstant=*/true, LLVM::Linkage::Internal,
|
loc, arrayTy, /*isConstant=*/true, LLVM::Linkage::Internal,
|
||||||
llvmSymbolName, builder.getStringAttr(funcMetadata.funcName()));
|
llvmSymbolName, builder.getStringAttr(funcMetadata.funcName()));
|
||||||
globalsByName[funcMetadata.funcName()] = global;
|
globalsByName[funcMetadata.funcName()] = global;
|
||||||
|
|
||||||
|
// Create constants for the input / output shapes
|
||||||
|
if (funcMetadata.inputShapes().hasValue()) {
|
||||||
|
auto i32ArrayInputSymbolName =
|
||||||
|
(Twine("__npcomp_internal_constant_input_shapes_") +
|
||||||
|
funcMetadata.funcName())
|
||||||
|
.str();
|
||||||
|
auto inputNumElements = funcMetadata.inputShapes()->getNumElements();
|
||||||
|
auto inputI32ArrayTy =
|
||||||
|
LLVMArrayType::get(builder.getIntegerType(32), inputNumElements);
|
||||||
|
auto inputShapesGlobal = builder.create<LLVM::GlobalOp>(
|
||||||
|
loc, inputI32ArrayTy, /*isConstant=*/true, LLVM::Linkage::Internal,
|
||||||
|
i32ArrayInputSymbolName,
|
||||||
|
/*value=*/funcMetadata.inputShapes().getValue());
|
||||||
|
|
||||||
|
inputShapesByName[funcMetadata.funcName()] = inputShapesGlobal;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (funcMetadata.outputShapes().hasValue()) {
|
||||||
|
auto i32ArrayOutputSymbolName =
|
||||||
|
(Twine("__npcomp_internal_constant_output_shapes_") +
|
||||||
|
funcMetadata.funcName())
|
||||||
|
.str();
|
||||||
|
auto outputNumElements = funcMetadata.outputShapes()->getNumElements();
|
||||||
|
auto outputI32ArrayTy =
|
||||||
|
LLVMArrayType::get(builder.getIntegerType(32), outputNumElements);
|
||||||
|
auto outputShapesGlobal = builder.create<LLVM::GlobalOp>(
|
||||||
|
loc, outputI32ArrayTy, /*isConstant=*/true, LLVM::Linkage::Internal,
|
||||||
|
i32ArrayOutputSymbolName,
|
||||||
|
/*value=*/funcMetadata.outputShapes().getValue());
|
||||||
|
|
||||||
|
outputShapesByName[funcMetadata.funcName()] = outputShapesGlobal;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
auto updateDescriptor = [&](Value &descriptor, Value value,
|
||||||
|
std::initializer_list<int32_t> position) {
|
||||||
|
descriptor = builder.create<LLVM::InsertValueOp>(
|
||||||
|
loc, descriptor, value,
|
||||||
|
/*position=*/builder.getI32ArrayAttr(position));
|
||||||
|
};
|
||||||
|
auto updateDescriptorWithI32Attr =
|
||||||
|
[&](Value &descriptor, Attribute attr,
|
||||||
|
std::initializer_list<int32_t> position) {
|
||||||
|
auto constant = builder.create<LLVM::ConstantOp>(loc, llvmI32Ty, attr);
|
||||||
|
updateDescriptor(descriptor, constant, position);
|
||||||
|
};
|
||||||
|
|
||||||
|
// Create global input descriptors
|
||||||
|
for (auto funcMetadata : funcMetadatas) {
|
||||||
|
std::string llvmInputSymbolName =
|
||||||
|
(Twine("__npcomp_input_descriptors_") + funcMetadata.funcName()).str();
|
||||||
|
auto inputDescriptorTy = getInputDescriptorTy(builder.getContext());
|
||||||
|
auto inputDescriptorArrayTy =
|
||||||
|
LLVMArrayType::get(inputDescriptorTy, funcMetadata.numInputs());
|
||||||
|
auto inputDescriptorArrayGlobal = builder.create<LLVM::GlobalOp>(
|
||||||
|
loc, inputDescriptorArrayTy, /*isConstant=*/true,
|
||||||
|
LLVM::Linkage::Internal, llvmInputSymbolName, /*value=*/Attribute());
|
||||||
|
|
||||||
|
OpBuilder::InsertionGuard guard(builder);
|
||||||
|
builder.createBlock(&inputDescriptorArrayGlobal.initializer());
|
||||||
|
|
||||||
|
auto c0 = builder.create<LLVM::ConstantOp>(loc, llvmI32Ty,
|
||||||
|
builder.getI32IntegerAttr(0));
|
||||||
|
|
||||||
|
Value inputDescriptorArray =
|
||||||
|
builder.create<LLVM::UndefOp>(loc, inputDescriptorArrayTy);
|
||||||
|
|
||||||
|
for (int i = 0; i < funcMetadata.numInputs(); i++) {
|
||||||
|
// Arg Type
|
||||||
|
if (!funcMetadata.inputArgTypes().hasValue())
|
||||||
|
funcMetadata.emitError()
|
||||||
|
<< "numInputs > 0 but there are no inputArgTypes?";
|
||||||
|
updateDescriptorWithI32Attr(inputDescriptorArray,
|
||||||
|
funcMetadata.inputArgTypes()->getValue(i),
|
||||||
|
{i, 0});
|
||||||
|
// Element Type
|
||||||
|
updateDescriptorWithI32Attr(inputDescriptorArray,
|
||||||
|
funcMetadata.inputElementTypes()->getValue(i),
|
||||||
|
{i, 1});
|
||||||
|
|
||||||
|
// Rank
|
||||||
|
// auto inputShapesType =
|
||||||
|
// funcMetadata.inputShapes()->getType().dyn_cast<ShapedType>();
|
||||||
|
auto rank = funcMetadata.inputRanks()->getValue(i);
|
||||||
|
updateDescriptorWithI32Attr(inputDescriptorArray, rank, {i, 2});
|
||||||
|
|
||||||
|
// Shape
|
||||||
|
// Each shape array is derived by offsetting by kMaxRank * arg index
|
||||||
|
auto extentsArray = builder.create<LLVM::AddressOfOp>(
|
||||||
|
loc, inputShapesByName[funcMetadata.funcName()]);
|
||||||
|
auto cShapeOffset = builder.create<LLVM::ConstantOp>(
|
||||||
|
loc, IntegerType::get(builder.getContext(), 32),
|
||||||
|
builder.getI32IntegerAttr(i * kMaxRank));
|
||||||
|
auto extentsArrayPtr = builder.create<LLVM::GEPOp>(
|
||||||
|
loc, getInt32PointerType(builder.getContext()), extentsArray,
|
||||||
|
ValueRange({c0, cShapeOffset}));
|
||||||
|
updateDescriptor(inputDescriptorArray, extentsArrayPtr, {i, 3});
|
||||||
|
}
|
||||||
|
|
||||||
|
builder.create<LLVM::ReturnOp>(loc, inputDescriptorArray);
|
||||||
|
|
||||||
|
inputDescriptorsByName[funcMetadata.funcName()] =
|
||||||
|
std::move(inputDescriptorArrayGlobal);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create global output descriptors
|
||||||
|
for (auto funcMetadata : funcMetadatas) {
|
||||||
|
std::string llvmOutputSymbolName =
|
||||||
|
(Twine("__npcomp_output_descriptors_") + funcMetadata.funcName()).str();
|
||||||
|
auto outputDescriptorTy = getOutputDescriptorTy(builder.getContext());
|
||||||
|
auto outputDescriptorArrayTy =
|
||||||
|
LLVMArrayType::get(outputDescriptorTy, funcMetadata.numOutputs());
|
||||||
|
auto outputDescriptorArrayGlobal = builder.create<LLVM::GlobalOp>(
|
||||||
|
loc, outputDescriptorArrayTy, /*isConstant=*/true,
|
||||||
|
LLVM::Linkage::Internal, llvmOutputSymbolName, /*value=*/Attribute());
|
||||||
|
|
||||||
|
OpBuilder::InsertionGuard guard(builder);
|
||||||
|
builder.createBlock(&outputDescriptorArrayGlobal.initializer());
|
||||||
|
|
||||||
|
auto c0 = builder.create<LLVM::ConstantOp>(loc, llvmI32Ty,
|
||||||
|
builder.getI32IntegerAttr(0));
|
||||||
|
|
||||||
|
Value outputDescriptorArray =
|
||||||
|
builder.create<LLVM::UndefOp>(loc, outputDescriptorArrayTy);
|
||||||
|
|
||||||
|
for (int i = 0; i < funcMetadata.numOutputs(); i++) {
|
||||||
|
if (!funcMetadata.outputArgTypes().hasValue())
|
||||||
|
funcMetadata.emitError()
|
||||||
|
<< "numOutputs > 0 but there are no outputArgTypes?";
|
||||||
|
// Arg Type
|
||||||
|
updateDescriptorWithI32Attr(outputDescriptorArray,
|
||||||
|
funcMetadata.outputArgTypes()->getValue(i),
|
||||||
|
{i, 0});
|
||||||
|
// Element Type
|
||||||
|
updateDescriptorWithI32Attr(
|
||||||
|
outputDescriptorArray, funcMetadata.outputElementTypes()->getValue(i),
|
||||||
|
{i, 1});
|
||||||
|
|
||||||
|
// Rank
|
||||||
|
// auto outputShapesType =
|
||||||
|
// funcMetadata.outputShapes()->getType().dyn_cast<ShapedType>();
|
||||||
|
auto rank = funcMetadata.outputRanks()->getValue(i);
|
||||||
|
updateDescriptorWithI32Attr(outputDescriptorArray, rank, {i, 2});
|
||||||
|
|
||||||
|
// Shapes
|
||||||
|
// Offset by kMaxRank * arg index
|
||||||
|
auto extentsArray = builder.create<LLVM::AddressOfOp>(
|
||||||
|
loc, outputShapesByName[funcMetadata.funcName()]);
|
||||||
|
auto cShapeOffset = builder.create<LLVM::ConstantOp>(
|
||||||
|
loc, IntegerType::get(builder.getContext(), 32),
|
||||||
|
builder.getI32IntegerAttr(i * kMaxRank));
|
||||||
|
auto extentsArrayPtr = builder.create<LLVM::GEPOp>(
|
||||||
|
loc, getInt32PointerType(builder.getContext()), extentsArray,
|
||||||
|
ValueRange({c0, cShapeOffset}));
|
||||||
|
updateDescriptor(outputDescriptorArray, extentsArrayPtr, {i, 3});
|
||||||
|
}
|
||||||
|
|
||||||
|
builder.create<LLVM::ReturnOp>(loc, outputDescriptorArray);
|
||||||
|
|
||||||
|
outputDescriptorsByName[funcMetadata.funcName()] =
|
||||||
|
outputDescriptorArrayGlobal;
|
||||||
}
|
}
|
||||||
|
|
||||||
// This must match FuncDescriptor in the runtime.
|
// This must match FuncDescriptor in the runtime.
|
||||||
|
@ -201,31 +411,23 @@ createFuncDescriptorArray(ArrayRef<refbackrt::FuncMetadataOp> funcMetadatas,
|
||||||
loc, funcDescriptorArrayTy, /*isConstant=*/true, LLVM::Linkage::Internal,
|
loc, funcDescriptorArrayTy, /*isConstant=*/true, LLVM::Linkage::Internal,
|
||||||
"__npcomp_func_descriptors",
|
"__npcomp_func_descriptors",
|
||||||
/*value=*/Attribute());
|
/*value=*/Attribute());
|
||||||
|
|
||||||
OpBuilder::InsertionGuard guard(builder);
|
OpBuilder::InsertionGuard guard(builder);
|
||||||
builder.createBlock(&funcDescriptorArrayGlobal.initializer());
|
builder.createBlock(&funcDescriptorArrayGlobal.initializer());
|
||||||
|
|
||||||
|
auto c0 = builder.create<LLVM::ConstantOp>(loc, llvmI32Ty,
|
||||||
|
builder.getI32IntegerAttr(0));
|
||||||
// Build the initializer.
|
// Build the initializer.
|
||||||
Value funcDescriptorArray =
|
Value funcDescriptorArray =
|
||||||
builder.create<LLVM::UndefOp>(loc, funcDescriptorArrayTy);
|
builder.create<LLVM::UndefOp>(loc, funcDescriptorArrayTy);
|
||||||
auto updateDescriptor = [&](Value value,
|
|
||||||
std::initializer_list<int32_t> position) {
|
|
||||||
funcDescriptorArray = builder.create<LLVM::InsertValueOp>(
|
|
||||||
loc, funcDescriptorArray, value,
|
|
||||||
/*position=*/builder.getI32ArrayAttr(position));
|
|
||||||
};
|
|
||||||
auto updateDescriptorWithI32Attr =
|
|
||||||
[&](Attribute attr, std::initializer_list<int32_t> position) {
|
|
||||||
auto constant = builder.create<LLVM::ConstantOp>(loc, llvmI32Ty, attr);
|
|
||||||
updateDescriptor(constant, position);
|
|
||||||
};
|
|
||||||
auto c0 = builder.create<LLVM::ConstantOp>(loc, llvmI32Ty,
|
|
||||||
builder.getI32IntegerAttr(0));
|
|
||||||
for (auto funcMetadataAndIndex : llvm::enumerate(funcMetadatas)) {
|
for (auto funcMetadataAndIndex : llvm::enumerate(funcMetadatas)) {
|
||||||
auto funcMetadata = funcMetadataAndIndex.value();
|
auto funcMetadata = funcMetadataAndIndex.value();
|
||||||
int32_t index = funcMetadataAndIndex.index();
|
int32_t index = funcMetadataAndIndex.index();
|
||||||
|
|
||||||
// Name length.
|
// Name length.
|
||||||
updateDescriptorWithI32Attr(
|
updateDescriptorWithI32Attr(
|
||||||
|
funcDescriptorArray,
|
||||||
builder.getI32IntegerAttr(funcMetadata.funcName().size()), {index, 0});
|
builder.getI32IntegerAttr(funcMetadata.funcName().size()), {index, 0});
|
||||||
|
|
||||||
// Name chars.
|
// Name chars.
|
||||||
|
@ -234,7 +436,7 @@ createFuncDescriptorArray(ArrayRef<refbackrt::FuncMetadataOp> funcMetadatas,
|
||||||
auto funcNamePtr = builder.create<LLVM::GEPOp>(
|
auto funcNamePtr = builder.create<LLVM::GEPOp>(
|
||||||
loc, getInt8PointerType(builder.getContext()), funcNameArray,
|
loc, getInt8PointerType(builder.getContext()), funcNameArray,
|
||||||
ValueRange({c0, c0}));
|
ValueRange({c0, c0}));
|
||||||
updateDescriptor(funcNamePtr, {index, 1});
|
updateDescriptor(funcDescriptorArray, funcNamePtr, {index, 1});
|
||||||
|
|
||||||
// Function pointer.
|
// Function pointer.
|
||||||
//
|
//
|
||||||
|
@ -247,13 +449,31 @@ createFuncDescriptorArray(ArrayRef<refbackrt::FuncMetadataOp> funcMetadatas,
|
||||||
loc, getInt8PointerType(builder.getContext()), funcMetadata.funcName());
|
loc, getInt8PointerType(builder.getContext()), funcMetadata.funcName());
|
||||||
auto typeErasedFuncAddress = builder.create<LLVM::BitcastOp>(
|
auto typeErasedFuncAddress = builder.create<LLVM::BitcastOp>(
|
||||||
loc, getInt8PointerType(builder.getContext()), funcAddress);
|
loc, getInt8PointerType(builder.getContext()), funcAddress);
|
||||||
updateDescriptor(typeErasedFuncAddress, {index, 2});
|
updateDescriptor(funcDescriptorArray, typeErasedFuncAddress, {index, 2});
|
||||||
|
|
||||||
// Number of inputs.
|
// Number of inputs.
|
||||||
updateDescriptorWithI32Attr(funcMetadata.numInputsAttr(), {index, 3});
|
updateDescriptorWithI32Attr(funcDescriptorArray,
|
||||||
|
funcMetadata.numInputsAttr(), {index, 3});
|
||||||
|
|
||||||
// Number of outputs.
|
// Number of outputs.
|
||||||
updateDescriptorWithI32Attr(funcMetadata.numOutputsAttr(), {index, 4});
|
updateDescriptorWithI32Attr(funcDescriptorArray,
|
||||||
|
funcMetadata.numOutputsAttr(), {index, 4});
|
||||||
|
|
||||||
|
// Input descriptors
|
||||||
|
auto inputDescriptorsArrayAddress = builder.create<LLVM::AddressOfOp>(
|
||||||
|
loc, inputDescriptorsByName[funcMetadata.funcName()]);
|
||||||
|
auto rawInputDescriptorsPtr = builder.create<LLVM::BitcastOp>(
|
||||||
|
loc, LLVMPointerType::get(getInputDescriptorTy(builder.getContext())),
|
||||||
|
inputDescriptorsArrayAddress);
|
||||||
|
updateDescriptor(funcDescriptorArray, rawInputDescriptorsPtr, {index, 5});
|
||||||
|
|
||||||
|
// Output descriptors
|
||||||
|
auto outputDescriptorsArrayAddress = builder.create<LLVM::AddressOfOp>(
|
||||||
|
loc, outputDescriptorsByName[funcMetadata.funcName()]);
|
||||||
|
auto rawOutputDescriptorsPtr = builder.create<LLVM::BitcastOp>(
|
||||||
|
loc, LLVMPointerType::get(getOutputDescriptorTy(builder.getContext())),
|
||||||
|
outputDescriptorsArrayAddress);
|
||||||
|
updateDescriptor(funcDescriptorArray, rawOutputDescriptorsPtr, {index, 6});
|
||||||
}
|
}
|
||||||
|
|
||||||
builder.create<LLVM::ReturnOp>(loc, funcDescriptorArray);
|
builder.create<LLVM::ReturnOp>(loc, funcDescriptorArray);
|
||||||
|
@ -379,6 +599,16 @@ static Type getUnrankedMemrefDescriptorType(MLIRContext *context) {
|
||||||
/*memorySpace=*/0));
|
/*memorySpace=*/0));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Returns the type obtained by converting the builtin `f64` type with an
// LLVMTypeConverter constructed for `context`.
static Type getDoubleType(MLIRContext *context) {
  Type f64Ty = FloatType::getF64(context);
  LLVMTypeConverter typeConverter(context);
  return typeConverter.convertType(f64Ty);
}
|
||||||
|
|
||||||
|
// Returns the type obtained by converting the builtin `f32` type with an
// LLVMTypeConverter constructed for `context`.
static Type getFloatType(MLIRContext *context) {
  Type f32Ty = FloatType::getF32(context);
  LLVMTypeConverter typeConverter(context);
  return typeConverter.convertType(f32Ty);
}
|
||||||
|
|
||||||
// Writes out the logical results of the wrapper function through the void**
|
// Writes out the logical results of the wrapper function through the void**
|
||||||
// passed on the ABI boundary. Because LLVM (and hence llvm.func)
|
// passed on the ABI boundary. Because LLVM (and hence llvm.func)
|
||||||
// only supports a single return type (or void/no results), the logic here needs
|
// only supports a single return type (or void/no results), the logic here needs
|
||||||
|
@ -393,12 +623,18 @@ static void storeWrapperResults(LLVM::CallOp callToWrapped, Value resultsPtrPtr,
|
||||||
return;
|
return;
|
||||||
Value result = callToWrapped.getResult(0);
|
Value result = callToWrapped.getResult(0);
|
||||||
auto ty = result.getType();
|
auto ty = result.getType();
|
||||||
|
|
||||||
// 1 logical result.
|
// 1 logical result.
|
||||||
if (ty == getUnrankedMemrefDescriptorType(ty.getContext())) {
|
if (ty == getUnrankedMemrefDescriptorType(ty.getContext())) {
|
||||||
Value addr =
|
Value addr =
|
||||||
getTypedAddressFromVoidStarStar(resultsPtrPtr, 0, ty, builder, loc);
|
getTypedAddressFromVoidStarStar(resultsPtrPtr, 0, ty, builder, loc);
|
||||||
builder.create<LLVM::StoreOp>(loc, result, addr);
|
builder.create<LLVM::StoreOp>(loc, result, addr);
|
||||||
return;
|
return;
|
||||||
|
} else if (ty == getFloatType(ty.getContext())) {
|
||||||
|
Value addr =
|
||||||
|
getTypedAddressFromVoidStarStar(resultsPtrPtr, 0, ty, builder, loc);
|
||||||
|
builder.create<LLVM::StoreOp>(loc, result, addr);
|
||||||
|
return;
|
||||||
}
|
}
|
||||||
assert(ty.isa<LLVMStructType>() && "must be a multi-result packed struct!");
|
assert(ty.isa<LLVMStructType>() && "must be a multi-result packed struct!");
|
||||||
auto structType = ty.cast<LLVMStructType>();
|
auto structType = ty.cast<LLVMStructType>();
|
||||||
|
|
|
@ -21,6 +21,16 @@
|
||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
using namespace mlir::NPCOMP;
|
using namespace mlir::NPCOMP;
|
||||||
|
|
||||||
|
// Since input/output shapes are not hyper-rectangular we specify
|
||||||
|
// a maximum rank for each input shape such that shapes are padded
|
||||||
|
// out to kMaxRank at the ABI boundary. That way we can represent
|
||||||
|
// shapes using a traditional DenseElementsAttr.
|
||||||
|
//
|
||||||
|
// NOTE: When changing this parameter, also change the same `kMaxRank`
|
||||||
|
// parameter in `lib/RefBackend/LowerToLLVM.cpp` so that the LLVM lowering
|
||||||
|
// stays consistent.
|
||||||
|
static constexpr int kMaxRank = 6;
|
||||||
|
|
||||||
// Get the type used to represent MemRefType `type` on ABI boundaries.
|
// Get the type used to represent MemRefType `type` on ABI boundaries.
|
||||||
// For convenience we do a cast to MemRefType internally.
|
// For convenience we do a cast to MemRefType internally.
|
||||||
static Type getABIMemrefType(Type type) {
|
static Type getABIMemrefType(Type type) {
|
||||||
|
@ -38,7 +48,99 @@ static bool expressibleWithRefbackrtABI(FunctionType type) {
|
||||||
// Currently, only memref types can be exposed at refbackrt ABI boundaries.
|
// Currently, only memref types can be exposed at refbackrt ABI boundaries.
|
||||||
return llvm::all_of(
|
return llvm::all_of(
|
||||||
llvm::concat<const Type>(type.getInputs(), type.getResults()),
|
llvm::concat<const Type>(type.getInputs(), type.getResults()),
|
||||||
[](Type t) { return t.isa<MemRefType>(); });
|
[](Type t) {
|
||||||
|
return t.isa<UnrankedMemRefType, MemRefType, FloatType>();
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
// Returns the integer rerpresentation of the CompilerDataStructures::ABIType
|
||||||
|
// Must stay aligned with CompilerDataStructures::ABIArgType enum
|
||||||
|
static uint32_t getIntReprForABIType(Type type) {
|
||||||
|
if (type.isa<MemRefType>() || type.isa<UnrankedMemRefType>()) {
|
||||||
|
return 1;
|
||||||
|
} else if (auto floatTy = type.dyn_cast<FloatType>()) {
|
||||||
|
switch (floatTy.getWidth()) {
|
||||||
|
case 32:
|
||||||
|
return 2;
|
||||||
|
case 64:
|
||||||
|
return 3;
|
||||||
|
default:
|
||||||
|
assert(false && "Unsupported float bit width");
|
||||||
|
}
|
||||||
|
} else if (auto intTy = type.dyn_cast<IntegerType>()) {
|
||||||
|
}
|
||||||
|
// assert(false && "couldn't get IntReprForABIType");
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Must stay aligned with CompilerDataStructures::ABIElementType enum
|
||||||
|
static uint32_t getIntReprForABIElementType(Type type) {
|
||||||
|
if (auto shapedTy = type.dyn_cast<ShapedType>()) {
|
||||||
|
auto elemTy = shapedTy.getElementType();
|
||||||
|
if (auto floatTy = elemTy.dyn_cast<FloatType>()) {
|
||||||
|
switch (floatTy.getWidth()) {
|
||||||
|
case 32:
|
||||||
|
return 1;
|
||||||
|
default:
|
||||||
|
assert(false && "Unsupported tensor element type");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Returns the extents of `type` in the padded form used for the flattened
// shape metadata.
//
//  * Ranked shaped type with rank <= maxRank: one entry per dimension (-1 for
//    a dynamic dimension), followed by 0xDEADBEEF padding entries so that the
//    result holds exactly `maxRank` values. Only the first `rank` entries are
//    meaningful.
//  * Unranked shaped type, or a type that is not shaped at all (e.g. a
//    scalar): a placeholder of kMaxRank entries, each equal to kMaxRank.
//  * Rank > maxRank: asserts; when assertions are compiled out the placeholder
//    above is returned.
static SmallVector<int32_t, kMaxRank>
getExtentsForType(Type type, const int32_t maxRank = kMaxRank) {
  auto shapedType = type.dyn_cast<ShapedType>();
  if (shapedType && shapedType.hasRank()) {
    const auto shapeRank = shapedType.getRank();
    if (shapeRank <= maxRank) {
      SmallVector<int32_t, kMaxRank> extendedShape;

      // Record every dimension. The dynamic check is made on the extent
      // itself: ShapedType::isDynamic is a predicate on a dimension size, not
      // on a dimension index.
      for (int64_t extent : shapedType.getShape()) {
        if (ShapedType::isDynamic(extent))
          extendedShape.push_back(-1);
        else
          extendedShape.push_back(static_cast<int32_t>(extent));
      }

      // Pad whatever is left so every shape occupies the same number of slots.
      for (int64_t i = shapeRank; i < maxRank; ++i)
        extendedShape.push_back(0xDEAD'BEEF);

      return extendedShape;
    }
    assert(false && "unsupported rank");
  }

  // Placeholder for unranked and non-shaped (scalar) types.
  return {kMaxRank, kMaxRank, kMaxRank, kMaxRank, kMaxRank, kMaxRank};
}
|
||||||
|
|
||||||
|
// Returns the rank recorded in the metadata for `type`:
//   * 0 for types that are not shaped types,
//   * -1 for unranked shaped types,
//   * the shape's rank otherwise.
int32_t getRankForType(Type type) {
  auto shapedType = type.dyn_cast<ShapedType>();
  if (!shapedType)
    return 0;
  if (!shapedType.hasRank())
    return -1;
  return shapedType.getRank();
}
|
||||||
|
|
||||||
|
// Returns 1 when `type` has a fully static shape and 0 otherwise. Scalars and
// other non-shaped types are assumed to be static.
uint32_t hasStaticShape(Type type) {
  auto shapedType = type.dyn_cast<ShapedType>();
  if (!shapedType)
    return 1;
  return shapedType.hasStaticShape() ? 1 : 0;
}
|
||||||
|
|
||||||
static LogicalResult createModuleMetadata(ModuleOp module) {
|
static LogicalResult createModuleMetadata(ModuleOp module) {
|
||||||
|
@ -59,10 +161,131 @@ static LogicalResult createModuleMetadata(ModuleOp module) {
|
||||||
}
|
}
|
||||||
// TODO: Add richer information here such as expected shapes and element
|
// TODO: Add richer information here such as expected shapes and element
|
||||||
// types.
|
// types.
|
||||||
builder.create<refbackrt::FuncMetadataOp>(
|
SmallVector<uint32_t, 6> inputABIArgTypes;
|
||||||
func.getLoc(), builder.getSymbolRefAttr(func.getName()),
|
SmallVector<uint32_t, 6> inputABIElementTypes;
|
||||||
builder.getI32IntegerAttr(func.getNumArguments()),
|
SmallVector<SmallVector<int32_t, kMaxRank>, 6> inputABIShapes;
|
||||||
builder.getI32IntegerAttr(func.getNumResults()));
|
SmallVector<uint32_t, 6> inputABIRanks;
|
||||||
|
// SmallVector<uint32_t, 6> inputIsStatic;
|
||||||
|
for (const auto &inputArgType : func.getBody().front().getArgumentTypes()) {
|
||||||
|
inputABIArgTypes.push_back(getIntReprForABIType(inputArgType));
|
||||||
|
inputABIElementTypes.push_back(getIntReprForABIElementType(inputArgType));
|
||||||
|
inputABIShapes.push_back(
|
||||||
|
getExtentsForType(inputArgType, /*maxRank=*/kMaxRank));
|
||||||
|
inputABIRanks.push_back(getRankForType(inputArgType));
|
||||||
|
// inputIsStatic.push_back(hasStaticShape(inputArgType));
|
||||||
|
}
|
||||||
|
|
||||||
|
SmallVector<uint32_t, 6> outputABIArgTypes;
|
||||||
|
SmallVector<uint32_t, 6> outputABIElementTypes;
|
||||||
|
SmallVector<SmallVector<int32_t, kMaxRank>, 6> outputABIShapes;
|
||||||
|
SmallVector<uint32_t, 6> outputABIRanks;
|
||||||
|
SmallVector<uint32_t, 6> outputIsStatic;
|
||||||
|
for (const auto &outputArgType : func.getCallableResults()) {
|
||||||
|
outputABIArgTypes.push_back(getIntReprForABIType(outputArgType));
|
||||||
|
outputABIElementTypes.push_back(
|
||||||
|
getIntReprForABIElementType(outputArgType));
|
||||||
|
outputABIShapes.push_back(
|
||||||
|
getExtentsForType(outputArgType, /*maxRank=*/kMaxRank));
|
||||||
|
outputABIRanks.push_back(getRankForType(outputArgType));
|
||||||
|
// outputIsStatic.push_back(hasStaticShape(outputArgType));
|
||||||
|
}
|
||||||
|
|
||||||
|
auto i32Type = builder.getIntegerType(32);
|
||||||
|
auto inputABIDataType =
|
||||||
|
RankedTensorType::get(inputABIArgTypes.size(), i32Type);
|
||||||
|
auto inputABIElementType =
|
||||||
|
RankedTensorType::get(inputABIElementTypes.size(), i32Type);
|
||||||
|
auto inputABIShapesType = RankedTensorType::get(
|
||||||
|
llvm::ArrayRef<int64_t>{static_cast<long>(inputABIShapes.size()) *
|
||||||
|
kMaxRank},
|
||||||
|
i32Type);
|
||||||
|
auto inputABIRanksType =
|
||||||
|
RankedTensorType::get(inputABIRanks.size(), i32Type);
|
||||||
|
// auto inputIsStaticType = RankedTensorType::get(inputIsStatic.size(),
|
||||||
|
// i32Type);
|
||||||
|
auto outputABIDataType =
|
||||||
|
RankedTensorType::get(outputABIArgTypes.size(), i32Type);
|
||||||
|
auto outputABIElementType =
|
||||||
|
RankedTensorType::get(outputABIElementTypes.size(), i32Type);
|
||||||
|
auto outputABIShapesType = RankedTensorType::get(
|
||||||
|
llvm::ArrayRef<int64_t>{static_cast<long>(outputABIShapes.size()) *
|
||||||
|
kMaxRank},
|
||||||
|
i32Type);
|
||||||
|
auto outputABIRanksType =
|
||||||
|
RankedTensorType::get(outputABIRanks.size(), i32Type);
|
||||||
|
// auto outputIsStaticType = RankedTensorType::get(outputIsStatic.size(),
|
||||||
|
// i32Type);
|
||||||
|
|
||||||
|
// TODO(brycearden): I'm sure there's a cleaner way to do this
|
||||||
|
auto flattenABIShapes =
|
||||||
|
[](SmallVector<SmallVector<int32_t, kMaxRank>, 6> shapes) {
|
||||||
|
SmallVector<int32_t, 32> ret;
|
||||||
|
for (auto &shape : shapes)
|
||||||
|
for (auto &dim : shape)
|
||||||
|
ret.push_back(dim);
|
||||||
|
return ret;
|
||||||
|
};
|
||||||
|
|
||||||
|
SmallVector<NamedAttribute, 16> namedAttrs;
|
||||||
|
|
||||||
|
// Add attributes that are valid for every func (funcName, numInputs,
|
||||||
|
// numOutputs)
|
||||||
|
namedAttrs.push_back(
|
||||||
|
std::make_pair(Identifier::get("funcName", module.getContext()),
|
||||||
|
builder.getSymbolRefAttr(func.getName())));
|
||||||
|
namedAttrs.push_back(
|
||||||
|
std::make_pair(Identifier::get("numInputs", module.getContext()),
|
||||||
|
builder.getI32IntegerAttr(func.getNumArguments())));
|
||||||
|
namedAttrs.push_back(
|
||||||
|
std::make_pair(Identifier::get("numOutputs", module.getContext()),
|
||||||
|
builder.getI32IntegerAttr(func.getNumResults())));
|
||||||
|
|
||||||
|
if (inputABIArgTypes.size()) {
|
||||||
|
// Only add input information if there are inputs
|
||||||
|
namedAttrs.push_back(std::make_pair(
|
||||||
|
Identifier::get("inputArgTypes", func.getContext()),
|
||||||
|
DenseIntElementsAttr::get(inputABIDataType,
|
||||||
|
llvm::makeArrayRef(inputABIArgTypes))));
|
||||||
|
namedAttrs.push_back(std::make_pair(
|
||||||
|
Identifier::get("inputElementTypes", func.getContext()),
|
||||||
|
DenseIntElementsAttr::get(inputABIElementType,
|
||||||
|
llvm::makeArrayRef(inputABIElementTypes))));
|
||||||
|
namedAttrs.push_back(std::make_pair(
|
||||||
|
Identifier::get("inputRanks", func.getContext()),
|
||||||
|
DenseIntElementsAttr::get(inputABIRanksType,
|
||||||
|
llvm::makeArrayRef(inputABIRanks))));
|
||||||
|
auto inputShapesFlattened = flattenABIShapes(inputABIShapes);
|
||||||
|
namedAttrs.push_back(std::make_pair(
|
||||||
|
Identifier::get("inputShapes", func.getContext()),
|
||||||
|
DenseIntElementsAttr::get(
|
||||||
|
inputABIShapesType,
|
||||||
|
llvm::makeArrayRef(flattenABIShapes(inputABIShapes)))));
|
||||||
|
}
|
||||||
|
|
||||||
|
if (outputABIArgTypes.size()) {
|
||||||
|
// Only add output information if there are outputs
|
||||||
|
namedAttrs.push_back(std::make_pair(
|
||||||
|
Identifier::get("outputArgTypes", func.getContext()),
|
||||||
|
DenseIntElementsAttr::get(outputABIDataType,
|
||||||
|
llvm::makeArrayRef(outputABIArgTypes))));
|
||||||
|
namedAttrs.push_back(std::make_pair(
|
||||||
|
Identifier::get("outputElementTypes", func.getContext()),
|
||||||
|
DenseIntElementsAttr::get(
|
||||||
|
outputABIElementType,
|
||||||
|
llvm::makeArrayRef(outputABIElementTypes))));
|
||||||
|
namedAttrs.push_back(std::make_pair(
|
||||||
|
Identifier::get("outputRanks", func.getContext()),
|
||||||
|
DenseIntElementsAttr::get(outputABIRanksType,
|
||||||
|
llvm::makeArrayRef(outputABIRanks))));
|
||||||
|
namedAttrs.push_back(std::make_pair(
|
||||||
|
Identifier::get("outputShapes", func.getContext()),
|
||||||
|
DenseIntElementsAttr::get(
|
||||||
|
outputABIShapesType,
|
||||||
|
llvm::makeArrayRef(flattenABIShapes(outputABIShapes)))));
|
||||||
|
}
|
||||||
|
|
||||||
|
builder.create<refbackrt::FuncMetadataOp>(func.getLoc(), ArrayRef<Type>{},
|
||||||
|
ArrayRef<Value>{}, namedAttrs);
|
||||||
|
|
||||||
if (!expressibleWithRefbackrtABI(func.getType()))
|
if (!expressibleWithRefbackrtABI(func.getType()))
|
||||||
return func.emitError() << "func not expressible with refbackrt ABI";
|
return func.emitError() << "func not expressible with refbackrt ABI";
|
||||||
|
|
|
@ -23,6 +23,40 @@ namespace refbackrt {
|
||||||
// LowerToLLVM.cpp for more details.
|
// LowerToLLVM.cpp for more details.
|
||||||
typedef void ABIFunc(void **, void **);
|
typedef void ABIFunc(void **, void **);
|
||||||
|
|
||||||
|
// Kind of a value crossing the ABI boundary. The numeric values are spelled
// out because they are written into the compiler-emitted descriptors as plain
// 32-bit integers and therefore must not change when enumerators are added.
enum class ABIArgType : std::uint32_t {
  kNone = 0,
  kMemref = 1,
  kF32 = 2,
  kF64 = 3,
};
|
||||||
|
|
||||||
|
// Element type of an ABI value that has one (e.g. the element type of a
// memref). The numeric values are spelled out because they are written into
// the compiler-emitted descriptors as plain 32-bit integers.
enum class ABIElementType : std::uint32_t {
  kNone = 0,
  kF32 = 1,
};
|
||||||
|
|
||||||
|
struct InputDescriptor {
|
||||||
|
ABIArgType abiType;
|
||||||
|
ABIElementType elementType;
|
||||||
|
|
||||||
|
std::int32_t rank;
|
||||||
|
std::int32_t* extents;
|
||||||
|
|
||||||
|
// TODO(brycearden): Change to bool at ABI boundary
|
||||||
|
// std::int32_t isStatic;
|
||||||
|
};
|
||||||
|
|
||||||
|
struct OutputDescriptor {
|
||||||
|
ABIArgType abiType;
|
||||||
|
ABIElementType elementType;
|
||||||
|
|
||||||
|
std::int32_t rank;
|
||||||
|
std::int32_t* extents;
|
||||||
|
|
||||||
|
// TODO(brycearden): Change to bool at ABI boundary
|
||||||
|
//std::int32_t isStatic;
|
||||||
|
};
|
||||||
|
|
||||||
struct FuncDescriptor {
|
struct FuncDescriptor {
|
||||||
// The length of the function name.
|
// The length of the function name.
|
||||||
std::int32_t nameLen;
|
std::int32_t nameLen;
|
||||||
|
@ -35,9 +69,9 @@ struct FuncDescriptor {
|
||||||
std::int32_t numInputs;
|
std::int32_t numInputs;
|
||||||
// The number of outputs of the function.
|
// The number of outputs of the function.
|
||||||
std::int32_t numOutputs;
|
std::int32_t numOutputs;
|
||||||
// TODO: Add arg/result descriptors and other metadata.
|
// TODO: Add shape checking to arg / result descriptor(s)
|
||||||
// With those descriptors we can do type and shape checking for each
|
InputDescriptor *inputDescriptors;
|
||||||
// argument.
|
OutputDescriptor *outputDescriptors;
|
||||||
};
|
};
|
||||||
|
|
||||||
// The top-level entry point of the module metadata emitted by the
|
// The top-level entry point of the module metadata emitted by the
|
||||||
|
|
|
@ -118,12 +118,38 @@ static std::int32_t totalElements(ArrayRef<std::int32_t> extents) {
|
||||||
|
|
||||||
std::int32_t refbackrt::getElementTypeByteSize(ElementType type) {
|
std::int32_t refbackrt::getElementTypeByteSize(ElementType type) {
|
||||||
switch (type) {
|
switch (type) {
|
||||||
|
case ElementType::NONE:
|
||||||
|
return 0;
|
||||||
case ElementType::F32:
|
case ElementType::F32:
|
||||||
return 4;
|
return 4;
|
||||||
}
|
}
|
||||||
llvm_unreachable("unsupported dtype");
|
llvm_unreachable("unsupported dtype");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Returns the name of `type` as a string: "NONE" for ElementType::NONE and
// "F32" for ElementType::F32. The switch deliberately has no default label;
// a value that matches none of the cases reaches llvm_unreachable.
StringRef refbackrt::getElementTypeAsStringRef(ElementType type) {
|
||||||
|
switch (type) {
|
||||||
|
case ElementType::NONE:
|
||||||
|
return "NONE";
|
||||||
|
case ElementType::F32:
|
||||||
|
return "F32";
|
||||||
|
}
|
||||||
|
llvm_unreachable("unsupported element type string");
|
||||||
|
}
|
||||||
|
|
||||||
|
// Returns the name of `type` as a string, identical to the enumerator's
// spelling ("kNone", "kTensor", "kF32", "kF64"). The switch deliberately has
// no default label; a value that matches none of the cases reaches
// llvm_unreachable.
StringRef refbackrt::getArgTypeAsStringRef(ArgType type) {
|
||||||
|
switch (type) {
|
||||||
|
case ArgType::kNone:
|
||||||
|
return "kNone";
|
||||||
|
case ArgType::kTensor:
|
||||||
|
return "kTensor";
|
||||||
|
case ArgType::kF32:
|
||||||
|
return "kF32";
|
||||||
|
case ArgType::kF64:
|
||||||
|
return "kF64";
|
||||||
|
}
|
||||||
|
llvm_unreachable("unsupported arg type string");
|
||||||
|
}
|
||||||
|
|
||||||
Ref<Tensor> Tensor::create(ArrayRef<std::int32_t> extents, ElementType type,
|
Ref<Tensor> Tensor::create(ArrayRef<std::int32_t> extents, ElementType type,
|
||||||
void *data) {
|
void *data) {
|
||||||
return Ref<Tensor>(createRaw(extents, type, data));
|
return Ref<Tensor>(createRaw(extents, type, data));
|
||||||
|
@ -192,10 +218,7 @@ void refbackrt::invoke(ModuleDescriptor *moduleDescriptor,
|
||||||
// Deepcopy the refbackrt::Tensor's into UnrankedMemref's.
|
// Deepcopy the refbackrt::Tensor's into UnrankedMemref's.
|
||||||
// TODO: Avoid the deep copy. It makes the later lifetime management code
|
// TODO: Avoid the deep copy. It makes the later lifetime management code
|
||||||
// more complex though (and maybe impossible given the current abstractions).
|
// more complex though (and maybe impossible given the current abstractions).
|
||||||
for (int i = 0, e = inputs.size(); i < e; i++) {
|
//
|
||||||
inputUnrankedMemrefs[i] =
|
|
||||||
convertRefbackrtTensorToUnrankedMemref(inputs[i].toTensor().get());
|
|
||||||
}
|
|
||||||
// Create a type-erased list of "packed inputs" to pass to the
|
// Create a type-erased list of "packed inputs" to pass to the
|
||||||
// LLVM/C ABI wrapper function. Each packedInput pointer corresponds to
|
// LLVM/C ABI wrapper function. Each packedInput pointer corresponds to
|
||||||
// one LLVM/C ABI argument to the underlying function.
|
// one LLVM/C ABI argument to the underlying function.
|
||||||
|
@ -204,16 +227,30 @@ void refbackrt::invoke(ModuleDescriptor *moduleDescriptor,
|
||||||
// "explode" the unranked memref descriptors on the underlying function
|
// "explode" the unranked memref descriptors on the underlying function
|
||||||
// into separate arguments for the rank and pointer-to-descriptor.
|
// into separate arguments for the rank and pointer-to-descriptor.
|
||||||
for (int i = 0, e = inputs.size(); i < e; i++) {
|
for (int i = 0, e = inputs.size(); i < e; i++) {
|
||||||
packedInputs[2 * i] = ToVoidPtr(&inputUnrankedMemrefs[i].rank);
|
auto idx = 2 * i;
|
||||||
packedInputs[2 * i + 1] = ToVoidPtr(&inputUnrankedMemrefs[i].descriptor);
|
if (inputs[i].isTensor()) {
|
||||||
|
inputUnrankedMemrefs[i] =
|
||||||
|
convertRefbackrtTensorToUnrankedMemref(inputs[i].toTensor().get());
|
||||||
|
packedInputs[idx] = ToVoidPtr(&inputUnrankedMemrefs[i].rank);
|
||||||
|
packedInputs[idx + 1] = ToVoidPtr(&inputUnrankedMemrefs[i].descriptor);
|
||||||
|
} else if (inputs[i].isScalar()) {
|
||||||
|
packedInputs[idx] = ToVoidPtr(&inputs[i]);
|
||||||
|
} else {
|
||||||
|
assert(false && "unsupported input RtValue type");
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create a type-erased list of "packed output" to pass to the
|
// Create a type-erased list of "packed output" to pass to the
|
||||||
// LLVM/C ABI wrapper function.
|
// LLVM/C ABI wrapper function.
|
||||||
//
|
//
|
||||||
// Due to how StandardToLLVM lowering works, each packedOutput pointer
|
// Due to how StandardToLLVM lowering works, each packedOutput pointer
|
||||||
// corresponds to a single UnrankedMemref (not "exploded").
|
// corresponds to a single UnrankedMemref (not "exploded").
|
||||||
for (int i = 0, e = outputs.size(); i < e; i++) {
|
for (int i = 0, e = outputs.size(); i < e; i++) {
|
||||||
packedOutputs[i] = ToVoidPtr(&outputUnrankedMemrefs[i]);
|
if (outputs[i].isTensor()) {
|
||||||
|
packedOutputs[i] = ToVoidPtr(&outputUnrankedMemrefs[i]);
|
||||||
|
} else if (outputs[i].isScalar()) {
|
||||||
|
packedOutputs[i] = ToVoidPtr(&outputs[i]);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Actually invoke the function!
|
// Actually invoke the function!
|
||||||
|
@ -223,11 +260,15 @@ void refbackrt::invoke(ModuleDescriptor *moduleDescriptor,
|
||||||
// TODO: Avoid needing to make a deep copy.
|
// TODO: Avoid needing to make a deep copy.
|
||||||
for (int i = 0, e = outputs.size(); i < e; i++) {
|
for (int i = 0, e = outputs.size(); i < e; i++) {
|
||||||
// TODO: Have compiler emit the element type in the metadata.
|
// TODO: Have compiler emit the element type in the metadata.
|
||||||
auto elementType = ElementType::F32;
|
if (outputs[i].isTensor()) {
|
||||||
Tensor *tensor = convertUnrankedMemrefToRefbackrtTensor(
|
auto elementType = ElementType::F32;
|
||||||
outputUnrankedMemrefs[i].rank, outputUnrankedMemrefs[i].descriptor,
|
Tensor *tensor = convertUnrankedMemrefToRefbackrtTensor(
|
||||||
elementType);
|
outputUnrankedMemrefs[i].rank, outputUnrankedMemrefs[i].descriptor,
|
||||||
outputs[i] = RtValue(Ref<Tensor>(tensor));
|
elementType);
|
||||||
|
outputs[i] = RtValue(Ref<Tensor>(tensor));
|
||||||
|
} else if (outputs[i].isFloat()) {
|
||||||
|
outputs[i] = RtValue(*(reinterpret_cast<float *>(packedOutputs[i])));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Now, we just need to free all the UnrankedMemref's that we created.
|
// Now, we just need to free all the UnrankedMemref's that we created.
|
||||||
|
@ -239,24 +280,30 @@ void refbackrt::invoke(ModuleDescriptor *moduleDescriptor,
|
||||||
|
|
||||||
// Free the output buffers.
|
// Free the output buffers.
|
||||||
for (int i = 0, e = outputs.size(); i < e; i++) {
|
for (int i = 0, e = outputs.size(); i < e; i++) {
|
||||||
void *allocatedPtr = outputUnrankedMemrefs[i].descriptor->allocatedPtr;
|
if (outputs[i].isRef()) {
|
||||||
// Multiple returned memrefs can point into the same underlying
|
void *allocatedPtr = outputUnrankedMemrefs[i].descriptor->allocatedPtr;
|
||||||
// malloc allocation. Do a linear scan to see if any of the previously
|
// Multiple returned memrefs can point into the same underlying
|
||||||
// deallocated buffers already freed this pointer.
|
// malloc allocation. Do a linear scan to see if any of the previously
|
||||||
bool bufferNeedsFreeing = true;
|
// deallocated buffers already freed this pointer.
|
||||||
for (int j = 0; j < i; j++) {
|
bool bufferNeedsFreeing = true;
|
||||||
if (allocatedPtr == outputUnrankedMemrefs[j].descriptor->allocatedPtr)
|
for (int j = 0; j < i; j++) {
|
||||||
bufferNeedsFreeing = false;
|
if (allocatedPtr == outputUnrankedMemrefs[j].descriptor->allocatedPtr)
|
||||||
|
bufferNeedsFreeing = false;
|
||||||
|
}
|
||||||
|
if (!bufferNeedsFreeing)
|
||||||
|
std::free(allocatedPtr);
|
||||||
}
|
}
|
||||||
if (!bufferNeedsFreeing)
|
|
||||||
std::free(allocatedPtr);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Free the input buffers.
|
// Free the input buffers.
|
||||||
for (int i = 0, e = inputs.size(); i < e; i++) {
|
for (int i = 0, e = inputs.size(); i < e; i++) {
|
||||||
|
if (!inputs[i].isRef())
|
||||||
|
continue;
|
||||||
void *allocatedPtr = inputUnrankedMemrefs[i].descriptor->allocatedPtr;
|
void *allocatedPtr = inputUnrankedMemrefs[i].descriptor->allocatedPtr;
|
||||||
bool bufferNeedsFreeing = true;
|
bool bufferNeedsFreeing = true;
|
||||||
for (int j = 0, je = outputs.size(); j < je; j++) {
|
for (int j = 0, je = outputs.size(); j < je; j++) {
|
||||||
|
if (!outputs[j].isRef())
|
||||||
|
continue;
|
||||||
if (allocatedPtr == outputUnrankedMemrefs[j].descriptor->allocatedPtr)
|
if (allocatedPtr == outputUnrankedMemrefs[j].descriptor->allocatedPtr)
|
||||||
bufferNeedsFreeing = false;
|
bufferNeedsFreeing = false;
|
||||||
}
|
}
|
||||||
|
@ -274,6 +321,8 @@ void refbackrt::invoke(ModuleDescriptor *moduleDescriptor,
|
||||||
|
|
||||||
// Free the output descriptors.
|
// Free the output descriptors.
|
||||||
for (int i = 0, e = outputs.size(); i < e; i++) {
|
for (int i = 0, e = outputs.size(); i < e; i++) {
|
||||||
|
if (!outputs[i].isRef())
|
||||||
|
continue;
|
||||||
// The LLVM lowering guarantees that each returned unranked memref
|
// The LLVM lowering guarantees that each returned unranked memref
|
||||||
// descriptor is separately malloc'ed, so no need to do anything special
|
// descriptor is separately malloc'ed, so no need to do anything special
|
||||||
// like we had to do for the allocatedPtr's.
|
// like we had to do for the allocatedPtr's.
|
||||||
|
@ -281,10 +330,81 @@ void refbackrt::invoke(ModuleDescriptor *moduleDescriptor,
|
||||||
}
|
}
|
||||||
// Free the input descriptors.
|
// Free the input descriptors.
|
||||||
for (int i = 0, e = inputs.size(); i < e; i++) {
|
for (int i = 0, e = inputs.size(); i < e; i++) {
|
||||||
|
if (!inputs[i].isRef())
|
||||||
|
continue;
|
||||||
std::free(inputUnrankedMemrefs[i].descriptor);
|
std::free(inputUnrankedMemrefs[i].descriptor);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Translates a compiler-emitted input descriptor into the externally visible
/// `InputArgInfo` (argument type, element type, rank and extents).
///
/// Only the first `rank` entries of `extents` are copied from the descriptor;
/// for a negative rank nothing is copied.
static InputArgInfo
getExternalInputArgInfo(const refbackrt::InputDescriptor &inputDescriptor) {
  // Value-initialize so that no field (including the unused tail of `extents`)
  // is left indeterminate, and start from a well-defined "none" state in case
  // the ABI type is not recognized below and asserts are compiled out.
  InputArgInfo ret{};
  ret.argType = ArgType::kNone;
  ret.elementType = ElementType::NONE;

  // Set arg / element types accordingly
  switch (inputDescriptor.abiType) {
  case ABIArgType::kNone:
    ret.argType = ArgType::kNone;
    ret.elementType = ElementType::NONE;
    break;
  case ABIArgType::kMemref:
    // Memrefs are surfaced as tensors; the element type is hard-coded to F32.
    ret.argType = ArgType::kTensor;
    ret.elementType = ElementType::F32;
    break;
  case ABIArgType::kF32:
    ret.argType = ArgType::kF32;
    ret.elementType = ElementType::NONE;
    break;
  case ABIArgType::kF64:
    ret.argType = ArgType::kF64;
    ret.elementType = ElementType::NONE;
    break;
  default:
    assert(false && "need to update external internal map");
    break;
  }

  // Extract shape information
  ret.rank = inputDescriptor.rank;
  for (int i = 0; i < inputDescriptor.rank; i++) {
    ret.extents[i] = inputDescriptor.extents[i];
  }

  return ret;
}
|
||||||
|
|
||||||
|
/// Translates a compiler-emitted output descriptor into the externally visible
/// `OutputArgInfo` (argument type, element type, rank and extents).
///
/// Only the first `rank` entries of `extents` are copied from the descriptor;
/// for a negative rank nothing is copied.
static OutputArgInfo
getExternalOutputArgInfo(const refbackrt::OutputDescriptor &outputDescriptor) {
  // Value-initialize so that no field (including the unused tail of `extents`)
  // is left indeterminate, and start from a well-defined "none" state in case
  // the ABI type is not recognized below and asserts are compiled out.
  OutputArgInfo ret{};
  ret.argType = ArgType::kNone;
  ret.elementType = ElementType::NONE;

  // Set arg / element types accordingly
  switch (outputDescriptor.abiType) {
  case ABIArgType::kNone:
    ret.argType = ArgType::kNone;
    ret.elementType = ElementType::NONE;
    break;
  case ABIArgType::kMemref:
    // Memrefs are surfaced as tensors; the element type is hard-coded to F32.
    ret.argType = ArgType::kTensor;
    ret.elementType = ElementType::F32;
    break;
  case ABIArgType::kF32:
    ret.argType = ArgType::kF32;
    ret.elementType = ElementType::NONE;
    break;
  case ABIArgType::kF64:
    ret.argType = ArgType::kF64;
    ret.elementType = ElementType::NONE;
    break;
  default:
    assert(false && "need to update external internal map");
    break;
  }

  // Extract shape information
  ret.rank = outputDescriptor.rank;
  for (int i = 0; i < outputDescriptor.rank; i++) {
    ret.extents[i] = outputDescriptor.extents[i];
  }
  return ret;
}
|
||||||
|
|
||||||
LogicalResult refbackrt::getMetadata(ModuleDescriptor *moduleDescriptor,
|
LogicalResult refbackrt::getMetadata(ModuleDescriptor *moduleDescriptor,
|
||||||
StringRef functionName,
|
StringRef functionName,
|
||||||
FunctionMetadata &outMetadata) {
|
FunctionMetadata &outMetadata) {
|
||||||
|
@ -293,5 +413,107 @@ LogicalResult refbackrt::getMetadata(ModuleDescriptor *moduleDescriptor,
|
||||||
return failure();
|
return failure();
|
||||||
outMetadata.numInputs = descriptor->numInputs;
|
outMetadata.numInputs = descriptor->numInputs;
|
||||||
outMetadata.numOutputs = descriptor->numOutputs;
|
outMetadata.numOutputs = descriptor->numOutputs;
|
||||||
|
|
||||||
|
for (int i = 0; i < descriptor->numInputs; i++) {
|
||||||
|
outMetadata.inputArgInfos[i] =
|
||||||
|
getExternalInputArgInfo(descriptor->inputDescriptors[i]);
|
||||||
|
}
|
||||||
|
|
||||||
|
for (int i = 0; i < descriptor->numOutputs; i++) {
|
||||||
|
outMetadata.outputArgInfos[i] =
|
||||||
|
getExternalOutputArgInfo(descriptor->outputDescriptors[i]);
|
||||||
|
}
|
||||||
|
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Checks the shape of `value` against the compiler-provided `info`.
///
/// Only tensor values are shape-checked; every other kind of value succeeds.
/// A negative `info.rank` means the expected tensor is unranked and skips the
/// check entirely. Otherwise the ranks must match and every statically known
/// extent must equal the corresponding extent of the tensor.
LogicalResult refbackrt::checkRtValueShapes(const RtValue &value,
                                            const InputArgInfo &info) {
  if (value.isTensor()) {
    auto refTensor = value.toTensor();

    // Don't bother checking shapes for unranked tensors
    if (info.rank < 0)
      return success();

    if (refTensor->getRank() != info.rank)
      return failure();

    auto tensorExtents = refTensor->getExtents();
    for (int i = 0; i < info.rank; i++) {
      // If a dimension is dynamic, it is encoded as extent = -1
      // and we should skip checking over that dimension. Every non-negative
      // extent is static -- including 0 -- and has to match exactly.
      if (info.extents[i] >= 0 && (info.extents[i] != tensorExtents[i]))
        return failure();
    }
  }

  return success();
}
|
||||||
|
|
||||||
|
/// Checks that the kind of `value` matches what the compiled function expects
/// according to `info`.
///
/// Fails when a tensor is passed where `info.argType` is not `kTensor`, when a
/// float is passed where it is not `kF32`, when a tensor's element type is not
/// F32, or when a ref-counted value of any non-tensor kind is passed.
LogicalResult refbackrt::checkRtValueArgTypes(const RtValue &value,
                                              const InputArgInfo &info) {
  // Generic checks based on argType(s)
  if ((value.isTensor() && info.argType != ArgType::kTensor) ||
      (value.isFloat() && info.argType != ArgType::kF32))
    return failure();

  if (value.isRef()) {
    // Will need special error checking for ref-counted types
    // Currently only f32 tensors are supported
    if (value.isTensor()) {
      auto refTensor = value.toTensor();
      if (refTensor->getElementType() != ElementType::F32)
        return failure();
    } else {
      assert(false && "Unsupported input type checking for Ref type");
      // With asserts compiled out, reject the value instead of silently
      // accepting a ref type that was never validated.
      return failure();
    }
  }
  return success();
}
|
||||||
|
|
||||||
|
/// Creates a placeholder `RtValue` matching the output described by `info`.
///
/// * `kTensor`: a zero-filled tensor whose extents are taken from `info`;
///   every non-positive extent is replaced by a fixed placeholder size.
/// * `kF32`:    a fixed placeholder float.
/// * otherwise: asserts, and returns a default-constructed `RtValue` when
///   asserts are compiled out.
RtValue refbackrt::createRtValueFromOutputArgInfo(const OutputArgInfo &info) {
  constexpr int32_t kDynamicConstantShape = 100;
  switch (info.argType) {
  case ArgType::kTensor: {
    // HACK: for dynamic dims the shape will be negative, so for now we are
    // just going to create a tensor of size kDynamicConstantShape
    std::array<int32_t, kMaxRank> tensorShape;
    for (int i = 0; i < info.rank; i++) {
      tensorShape[i] =
          info.extents[i] > 0 ? info.extents[i] : kDynamicConstantShape;
    }
    refbackrt::ArrayRef<int32_t> shape(tensorShape.data(), info.rank);

    // Accumulate in size_t: kMaxRank placeholder extents multiplied together
    // do not fit in a 32-bit int. Every extent is positive at this point.
    size_t numel = 1;
    for (int i = 0; i < info.rank; i++)
      numel *= static_cast<size_t>(shape[i]);

    size_t byteSize = 0;
    switch (info.elementType) {
    case ElementType::F32:
      byteSize = numel * sizeof(float);
      break;
    default:
      assert(false && "unknown output tensor type");
      return RtValue();
    }

    // aligned_alloc requires the size to be a multiple of the alignment, so
    // round the request up; only the first `byteSize` bytes are meaningful.
    constexpr size_t kAlignment = 32;
    const size_t allocSize =
        ((byteSize + kAlignment - 1) / kAlignment) * kAlignment;
    void *data = aligned_alloc(kAlignment, allocSize);
    assert(data && "data ptr must exist");
    if (!data)
      return RtValue();
    memset(data, 0, allocSize);

    // The Tensor::create function will malloc and memcpy the data
    // into the Tensor object, so after we need to free our
    // temporary data buffer
    auto refTensor = Tensor::create(shape, info.elementType, data);
    free(data);
    return RtValue(refTensor);
  }
  case ArgType::kF32: {
    return RtValue(-20.0f);
  }
  default: {
    assert(false && "Don't know how to handle this artType");
    return RtValue();
  }
  }
}
|
|
@ -3,9 +3,17 @@
|
||||||
// CHECK: refbackrt.module_metadata
|
// CHECK: refbackrt.module_metadata
|
||||||
refbackrt.module_metadata {
|
refbackrt.module_metadata {
|
||||||
// CHECK: refbackrt.func_metadata
|
// CHECK: refbackrt.func_metadata
|
||||||
refbackrt.func_metadata {funcName = @f, numInputs = 1 : i32, numOutputs = 0 : i32}
|
// TODO(brycearden): Encode unranked memrefs in the ABI
|
||||||
|
refbackrt.func_metadata {
|
||||||
|
funcName = @f,
|
||||||
|
numInputs = 1 : i32,
|
||||||
|
numOutputs = 0 : i32,
|
||||||
|
inputArgTypes = dense<1> : tensor<1xi32>,
|
||||||
|
inputElementTypes = dense<1> : tensor<1xi32>,
|
||||||
|
inputRanks = dense<-1> : tensor<1xi32>,
|
||||||
|
inputShapes = dense<1> : tensor<4xi32>}
|
||||||
}
|
}
|
||||||
|
|
||||||
func @f(%arg0: memref<*xf32>) {
|
func @f(%arg0: tensor<*xf32>) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,165 +0,0 @@
|
||||||
// RUN: npcomp-opt -refback-lower-to-llvm -split-input-file <%s | FileCheck %s --dump-input=fail
|
|
||||||
|
|
||||||
// Test input/output arg marshaling.
|
|
||||||
|
|
||||||
// CHECK-LABEL: llvm.func @__refbackrt_wrapper_inputs1results2(
|
|
||||||
// CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr<ptr<i8>>,
|
|
||||||
// CHECK-SAME: %[[VAL_1:.*]]: !llvm.ptr<ptr<i8>>) {
|
|
||||||
// CHECK: %[[VAL_2:.*]] = llvm.mlir.constant(0 : i32) : i32
|
|
||||||
// CHECK: %[[VAL_3:.*]] = llvm.getelementptr %[[VAL_0]]{{\[}}%[[VAL_2]]] : (!llvm.ptr<ptr<i8>>, i32) -> !llvm.ptr<ptr<i8>>
|
|
||||||
// CHECK: %[[VAL_4:.*]] = llvm.load %[[VAL_3]] : !llvm.ptr<ptr<i8>>
|
|
||||||
// CHECK: %[[VAL_5:.*]] = llvm.bitcast %[[VAL_4]] : !llvm.ptr<i8> to !llvm.ptr<i64>
|
|
||||||
// CHECK: %[[VAL_6:.*]] = llvm.load %[[VAL_5]] : !llvm.ptr<i64>
|
|
||||||
// CHECK: %[[VAL_7:.*]] = llvm.mlir.constant(1 : i32) : i32
|
|
||||||
// CHECK: %[[VAL_8:.*]] = llvm.getelementptr %[[VAL_0]]{{\[}}%[[VAL_7]]] : (!llvm.ptr<ptr<i8>>, i32) -> !llvm.ptr<ptr<i8>>
|
|
||||||
// CHECK: %[[VAL_9:.*]] = llvm.load %[[VAL_8]] : !llvm.ptr<ptr<i8>>
|
|
||||||
// CHECK: %[[VAL_10:.*]] = llvm.bitcast %[[VAL_9]] : !llvm.ptr<i8> to !llvm.ptr<ptr<i8>>
|
|
||||||
// CHECK: %[[VAL_11:.*]] = llvm.load %[[VAL_10]] : !llvm.ptr<ptr<i8>>
|
|
||||||
// CHECK: %[[VAL_12:.*]] = llvm.call @inputs1results2(%[[VAL_6]], %[[VAL_11]]) : (i64, !llvm.ptr<i8>) -> !llvm.struct<(struct<(i64, ptr<i8>)>, struct<(i64, ptr<i8>)>)>
|
|
||||||
// CHECK: %[[VAL_13:.*]] = llvm.mlir.constant(0 : i32) : i32
|
|
||||||
// CHECK: %[[VAL_14:.*]] = llvm.getelementptr %[[VAL_1]]{{\[}}%[[VAL_13]]] : (!llvm.ptr<ptr<i8>>, i32) -> !llvm.ptr<ptr<i8>>
|
|
||||||
// CHECK: %[[VAL_15:.*]] = llvm.load %[[VAL_14]] : !llvm.ptr<ptr<i8>>
|
|
||||||
// CHECK: %[[VAL_16:.*]] = llvm.bitcast %[[VAL_15]] : !llvm.ptr<i8> to !llvm.ptr<struct<(i64, ptr<i8>)>>
|
|
||||||
// CHECK: %[[VAL_17:.*]] = llvm.extractvalue %[[VAL_12]][0 : i32] : !llvm.struct<(struct<(i64, ptr<i8>)>, struct<(i64, ptr<i8>)>)>
|
|
||||||
// CHECK: llvm.store %[[VAL_17]], %[[VAL_16]] : !llvm.ptr<struct<(i64, ptr<i8>)>>
|
|
||||||
// CHECK: %[[VAL_18:.*]] = llvm.mlir.constant(1 : i32) : i32
|
|
||||||
// CHECK: %[[VAL_19:.*]] = llvm.getelementptr %[[VAL_1]]{{\[}}%[[VAL_18]]] : (!llvm.ptr<ptr<i8>>, i32) -> !llvm.ptr<ptr<i8>>
|
|
||||||
// CHECK: %[[VAL_20:.*]] = llvm.load %[[VAL_19]] : !llvm.ptr<ptr<i8>>
|
|
||||||
// CHECK: %[[VAL_21:.*]] = llvm.bitcast %[[VAL_20]] : !llvm.ptr<i8> to !llvm.ptr<struct<(i64, ptr<i8>)>>
|
|
||||||
// CHECK: %[[VAL_22:.*]] = llvm.extractvalue %[[VAL_12]][1 : i32] : !llvm.struct<(struct<(i64, ptr<i8>)>, struct<(i64, ptr<i8>)>)>
|
|
||||||
// CHECK: llvm.store %[[VAL_22]], %[[VAL_21]] : !llvm.ptr<struct<(i64, ptr<i8>)>>
|
|
||||||
// CHECK: llvm.return
|
|
||||||
// CHECK: }
|
|
||||||
|
|
||||||
// CHECK-LABEL: llvm.func @__refbackrt_wrapper_inputs1results1(
|
|
||||||
// CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr<ptr<i8>>,
|
|
||||||
// CHECK-SAME: %[[VAL_1:.*]]: !llvm.ptr<ptr<i8>>) {
|
|
||||||
// CHECK: %[[VAL_2:.*]] = llvm.mlir.constant(0 : i32) : i32
|
|
||||||
// CHECK: %[[VAL_3:.*]] = llvm.getelementptr %[[VAL_0]]{{\[}}%[[VAL_2]]] : (!llvm.ptr<ptr<i8>>, i32) -> !llvm.ptr<ptr<i8>>
|
|
||||||
// CHECK: %[[VAL_4:.*]] = llvm.load %[[VAL_3]] : !llvm.ptr<ptr<i8>>
|
|
||||||
// CHECK: %[[VAL_5:.*]] = llvm.bitcast %[[VAL_4]] : !llvm.ptr<i8> to !llvm.ptr<i64>
|
|
||||||
// CHECK: %[[VAL_6:.*]] = llvm.load %[[VAL_5]] : !llvm.ptr<i64>
|
|
||||||
// CHECK: %[[VAL_7:.*]] = llvm.mlir.constant(1 : i32) : i32
|
|
||||||
// CHECK: %[[VAL_8:.*]] = llvm.getelementptr %[[VAL_0]]{{\[}}%[[VAL_7]]] : (!llvm.ptr<ptr<i8>>, i32) -> !llvm.ptr<ptr<i8>>
|
|
||||||
// CHECK: %[[VAL_9:.*]] = llvm.load %[[VAL_8]] : !llvm.ptr<ptr<i8>>
|
|
||||||
// CHECK: %[[VAL_10:.*]] = llvm.bitcast %[[VAL_9]] : !llvm.ptr<i8> to !llvm.ptr<ptr<i8>>
|
|
||||||
// CHECK: %[[VAL_11:.*]] = llvm.load %[[VAL_10]] : !llvm.ptr<ptr<i8>>
|
|
||||||
// CHECK: %[[VAL_12:.*]] = llvm.call @inputs1results1(%[[VAL_6]], %[[VAL_11]]) : (i64, !llvm.ptr<i8>) -> !llvm.struct<(i64, ptr<i8>)>
|
|
||||||
// CHECK: %[[VAL_13:.*]] = llvm.mlir.constant(0 : i32) : i32
|
|
||||||
// CHECK: %[[VAL_14:.*]] = llvm.getelementptr %[[VAL_1]]{{\[}}%[[VAL_13]]] : (!llvm.ptr<ptr<i8>>, i32) -> !llvm.ptr<ptr<i8>>
|
|
||||||
// CHECK: %[[VAL_15:.*]] = llvm.load %[[VAL_14]] : !llvm.ptr<ptr<i8>>
|
|
||||||
// CHECK: %[[VAL_16:.*]] = llvm.bitcast %[[VAL_15]] : !llvm.ptr<i8> to !llvm.ptr<struct<(i64, ptr<i8>)>>
|
|
||||||
// CHECK: llvm.store %[[VAL_12]], %[[VAL_16]] : !llvm.ptr<struct<(i64, ptr<i8>)>>
|
|
||||||
// CHECK: llvm.return
|
|
||||||
// CHECK: }
|
|
||||||
|
|
||||||
/// CHECK-LABEL: llvm.func @__refbackrt_wrapper_inputs1results0(
|
|
||||||
// CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr<ptr<i8>>,
|
|
||||||
// CHECK-SAME: %[[VAL_1:.*]]: !llvm.ptr<ptr<i8>>) {
|
|
||||||
// CHECK: %[[VAL_2:.*]] = llvm.mlir.constant(0 : i32) : i32
|
|
||||||
// CHECK: %[[VAL_3:.*]] = llvm.getelementptr %[[VAL_0]]{{\[}}%[[VAL_2]]] : (!llvm.ptr<ptr<i8>>, i32) -> !llvm.ptr<ptr<i8>>
|
|
||||||
// CHECK: %[[VAL_4:.*]] = llvm.load %[[VAL_3]] : !llvm.ptr<ptr<i8>>
|
|
||||||
// CHECK: %[[VAL_5:.*]] = llvm.bitcast %[[VAL_4]] : !llvm.ptr<i8> to !llvm.ptr<i64>
|
|
||||||
// CHECK: %[[VAL_6:.*]] = llvm.load %[[VAL_5]] : !llvm.ptr<i64>
|
|
||||||
// CHECK: %[[VAL_7:.*]] = llvm.mlir.constant(1 : i32) : i32
|
|
||||||
// CHECK: %[[VAL_8:.*]] = llvm.getelementptr %[[VAL_0]]{{\[}}%[[VAL_7]]] : (!llvm.ptr<ptr<i8>>, i32) -> !llvm.ptr<ptr<i8>>
|
|
||||||
// CHECK: %[[VAL_9:.*]] = llvm.load %[[VAL_8]] : !llvm.ptr<ptr<i8>>
|
|
||||||
// CHECK: %[[VAL_10:.*]] = llvm.bitcast %[[VAL_9]] : !llvm.ptr<i8> to !llvm.ptr<ptr<i8>>
|
|
||||||
// CHECK: %[[VAL_11:.*]] = llvm.load %[[VAL_10]] : !llvm.ptr<ptr<i8>>
|
|
||||||
// CHECK: llvm.call @inputs1results0(%[[VAL_6]], %[[VAL_11]]) : (i64, !llvm.ptr<i8>) -> ()
|
|
||||||
// CHECK: llvm.return
|
|
||||||
// CHECK: }
|
|
||||||
// CHECK: llvm.func @__npcomp_compiler_rt_abort_if(i1, !llvm.ptr<i8>)
|
|
||||||
// CHECK: llvm.mlir.global internal constant @__npcomp_internal_constant_inputs1results0("inputs1results0")
|
|
||||||
// CHECK: llvm.mlir.global internal constant @__npcomp_internal_constant_inputs1results1("inputs1results1")
|
|
||||||
// CHECK: llvm.mlir.global internal constant @__npcomp_internal_constant_inputs1results2("inputs1results2")
|
|
||||||
|
|
||||||
// CHECK-LABEL: llvm.mlir.global internal constant @__npcomp_func_descriptors() : !llvm.array<3 x struct<(i32, ptr<i8>, ptr<i8>, i32, i32)>> {
|
|
||||||
// CHECK: %[[VAL_0:.*]] = llvm.mlir.undef : !llvm.array<3 x struct<(i32, ptr<i8>, ptr<i8>, i32, i32)>>
|
|
||||||
// CHECK: %[[VAL_1:.*]] = llvm.mlir.constant(0 : i32) : i32
|
|
||||||
// CHECK: %[[VAL_2:.*]] = llvm.mlir.constant(15 : i32) : i32
|
|
||||||
// CHECK: %[[VAL_3:.*]] = llvm.insertvalue %[[VAL_2]], %[[VAL_0]][0 : i32, 0 : i32] : !llvm.array<3 x struct<(i32, ptr<i8>, ptr<i8>, i32, i32)>>
|
|
||||||
// CHECK: %[[VAL_4:.*]] = llvm.mlir.addressof @__npcomp_internal_constant_inputs1results0 : !llvm.ptr<array<15 x i8>>
|
|
||||||
// CHECK: %[[VAL_5:.*]] = llvm.getelementptr %[[VAL_4]]{{\[}}%[[VAL_1]], %[[VAL_1]]] : (!llvm.ptr<array<15 x i8>>, i32, i32) -> !llvm.ptr<i8>
|
|
||||||
// CHECK: %[[VAL_6:.*]] = llvm.insertvalue %[[VAL_5]], %[[VAL_3]][0 : i32, 1 : i32] : !llvm.array<3 x struct<(i32, ptr<i8>, ptr<i8>, i32, i32)>>
|
|
||||||
// CHECK: %[[VAL_7:.*]] = llvm.mlir.addressof @__refbackrt_wrapper_inputs1results0 : !llvm.ptr<func<void (ptr<ptr<i8>>, ptr<ptr<i8>>)>>
|
|
||||||
// CHECK: %[[VAL_8:.*]] = llvm.bitcast %[[VAL_7]] : !llvm.ptr<func<void (ptr<ptr<i8>>, ptr<ptr<i8>>)>> to !llvm.ptr<i8>
|
|
||||||
// CHECK: %[[VAL_9:.*]] = llvm.insertvalue %[[VAL_8]], %[[VAL_6]][0 : i32, 2 : i32] : !llvm.array<3 x struct<(i32, ptr<i8>, ptr<i8>, i32, i32)>>
|
|
||||||
// CHECK: %[[VAL_10:.*]] = llvm.mlir.constant(1 : i32) : i32
|
|
||||||
// CHECK: %[[VAL_11:.*]] = llvm.insertvalue %[[VAL_10]], %[[VAL_9]][0 : i32, 3 : i32] : !llvm.array<3 x struct<(i32, ptr<i8>, ptr<i8>, i32, i32)>>
|
|
||||||
// CHECK: %[[VAL_12:.*]] = llvm.mlir.constant(0 : i32) : i32
|
|
||||||
// CHECK: %[[VAL_13:.*]] = llvm.insertvalue %[[VAL_12]], %[[VAL_11]][0 : i32, 4 : i32] : !llvm.array<3 x struct<(i32, ptr<i8>, ptr<i8>, i32, i32)>>
|
|
||||||
// CHECK: %[[VAL_14:.*]] = llvm.mlir.constant(15 : i32) : i32
|
|
||||||
// CHECK: %[[VAL_15:.*]] = llvm.insertvalue %[[VAL_14]], %[[VAL_13]][1 : i32, 0 : i32] : !llvm.array<3 x struct<(i32, ptr<i8>, ptr<i8>, i32, i32)>>
|
|
||||||
// CHECK: %[[VAL_16:.*]] = llvm.mlir.addressof @__npcomp_internal_constant_inputs1results1 : !llvm.ptr<array<15 x i8>>
|
|
||||||
// CHECK: %[[VAL_17:.*]] = llvm.getelementptr %[[VAL_16]]{{\[}}%[[VAL_1]], %[[VAL_1]]] : (!llvm.ptr<array<15 x i8>>, i32, i32) -> !llvm.ptr<i8>
|
|
||||||
// CHECK: %[[VAL_18:.*]] = llvm.insertvalue %[[VAL_17]], %[[VAL_15]][1 : i32, 1 : i32] : !llvm.array<3 x struct<(i32, ptr<i8>, ptr<i8>, i32, i32)>>
|
|
||||||
// CHECK: %[[VAL_19:.*]] = llvm.mlir.addressof @__refbackrt_wrapper_inputs1results1 : !llvm.ptr<func<void (ptr<ptr<i8>>, ptr<ptr<i8>>)>>
|
|
||||||
// CHECK: %[[VAL_20:.*]] = llvm.bitcast %[[VAL_19]] : !llvm.ptr<func<void (ptr<ptr<i8>>, ptr<ptr<i8>>)>> to !llvm.ptr<i8>
|
|
||||||
// CHECK: %[[VAL_21:.*]] = llvm.insertvalue %[[VAL_20]], %[[VAL_18]][1 : i32, 2 : i32] : !llvm.array<3 x struct<(i32, ptr<i8>, ptr<i8>, i32, i32)>>
|
|
||||||
// CHECK: %[[VAL_22:.*]] = llvm.mlir.constant(1 : i32) : i32
|
|
||||||
// CHECK: %[[VAL_23:.*]] = llvm.insertvalue %[[VAL_22]], %[[VAL_21]][1 : i32, 3 : i32] : !llvm.array<3 x struct<(i32, ptr<i8>, ptr<i8>, i32, i32)>>
|
|
||||||
// CHECK: %[[VAL_24:.*]] = llvm.mlir.constant(1 : i32) : i32
|
|
||||||
// CHECK: %[[VAL_25:.*]] = llvm.insertvalue %[[VAL_24]], %[[VAL_23]][1 : i32, 4 : i32] : !llvm.array<3 x struct<(i32, ptr<i8>, ptr<i8>, i32, i32)>>
|
|
||||||
// CHECK: %[[VAL_26:.*]] = llvm.mlir.constant(15 : i32) : i32
|
|
||||||
// CHECK: %[[VAL_27:.*]] = llvm.insertvalue %[[VAL_26]], %[[VAL_25]][2 : i32, 0 : i32] : !llvm.array<3 x struct<(i32, ptr<i8>, ptr<i8>, i32, i32)>>
|
|
||||||
// CHECK: %[[VAL_28:.*]] = llvm.mlir.addressof @__npcomp_internal_constant_inputs1results2 : !llvm.ptr<array<15 x i8>>
|
|
||||||
// CHECK: %[[VAL_29:.*]] = llvm.getelementptr %[[VAL_28]]{{\[}}%[[VAL_1]], %[[VAL_1]]] : (!llvm.ptr<array<15 x i8>>, i32, i32) -> !llvm.ptr<i8>
|
|
||||||
// CHECK: %[[VAL_30:.*]] = llvm.insertvalue %[[VAL_29]], %[[VAL_27]][2 : i32, 1 : i32] : !llvm.array<3 x struct<(i32, ptr<i8>, ptr<i8>, i32, i32)>>
|
|
||||||
// CHECK: %[[VAL_31:.*]] = llvm.mlir.addressof @__refbackrt_wrapper_inputs1results2 : !llvm.ptr<func<void (ptr<ptr<i8>>, ptr<ptr<i8>>)>>
|
|
||||||
// CHECK: %[[VAL_32:.*]] = llvm.bitcast %[[VAL_31]] : !llvm.ptr<func<void (ptr<ptr<i8>>, ptr<ptr<i8>>)>> to !llvm.ptr<i8>
|
|
||||||
// CHECK: %[[VAL_33:.*]] = llvm.insertvalue %[[VAL_32]], %[[VAL_30]][2 : i32, 2 : i32] : !llvm.array<3 x struct<(i32, ptr<i8>, ptr<i8>, i32, i32)>>
|
|
||||||
// CHECK: %[[VAL_34:.*]] = llvm.mlir.constant(1 : i32) : i32
|
|
||||||
// CHECK: %[[VAL_35:.*]] = llvm.insertvalue %[[VAL_34]], %[[VAL_33]][2 : i32, 3 : i32] : !llvm.array<3 x struct<(i32, ptr<i8>, ptr<i8>, i32, i32)>>
|
|
||||||
// CHECK: %[[VAL_36:.*]] = llvm.mlir.constant(2 : i32) : i32
|
|
||||||
// CHECK: %[[VAL_37:.*]] = llvm.insertvalue %[[VAL_36]], %[[VAL_35]][2 : i32, 4 : i32] : !llvm.array<3 x struct<(i32, ptr<i8>, ptr<i8>, i32, i32)>>
|
|
||||||
// CHECK: llvm.return %[[VAL_37]] : !llvm.array<3 x struct<(i32, ptr<i8>, ptr<i8>, i32, i32)>>
|
|
||||||
// CHECK: }
|
|
||||||
|
|
||||||
// CHECK-LABEL: llvm.mlir.global external constant @_mlir___npcomp_module_descriptor() : !llvm.struct<(i32, ptr<struct<(i32, ptr<i8>, ptr<i8>, i32, i32)>>)> {
|
|
||||||
// CHECK: %[[VAL_0:.*]] = llvm.mlir.undef : !llvm.struct<(i32, ptr<struct<(i32, ptr<i8>, ptr<i8>, i32, i32)>>)>
|
|
||||||
// CHECK: %[[VAL_1:.*]] = llvm.mlir.constant(3 : i32) : i32
|
|
||||||
// CHECK: %[[VAL_2:.*]] = llvm.insertvalue %[[VAL_1]], %[[VAL_0]][0 : i32] : !llvm.struct<(i32, ptr<struct<(i32, ptr<i8>, ptr<i8>, i32, i32)>>)>
|
|
||||||
// CHECK: %[[VAL_3:.*]] = llvm.mlir.addressof @__npcomp_func_descriptors : !llvm.ptr<array<3 x struct<(i32, ptr<i8>, ptr<i8>, i32, i32)>>>
|
|
||||||
// CHECK: %[[VAL_4:.*]] = llvm.bitcast %[[VAL_3]] : !llvm.ptr<array<3 x struct<(i32, ptr<i8>, ptr<i8>, i32, i32)>>> to !llvm.ptr<struct<(i32, ptr<i8>, ptr<i8>, i32, i32)>>
|
|
||||||
// CHECK: %[[VAL_5:.*]] = llvm.insertvalue %[[VAL_4]], %[[VAL_2]][1 : i32] : !llvm.struct<(i32, ptr<struct<(i32, ptr<i8>, ptr<i8>, i32, i32)>>)>
|
|
||||||
// CHECK: llvm.return %[[VAL_5]] : !llvm.struct<(i32, ptr<struct<(i32, ptr<i8>, ptr<i8>, i32, i32)>>)>
|
|
||||||
// CHECK: }
|
|
||||||
|
|
||||||
refbackrt.module_metadata {
|
|
||||||
refbackrt.func_metadata {funcName = @inputs1results0, numInputs = 1 : i32, numOutputs = 0 : i32}
|
|
||||||
refbackrt.func_metadata {funcName = @inputs1results1, numInputs = 1 : i32, numOutputs = 1 : i32}
|
|
||||||
refbackrt.func_metadata {funcName = @inputs1results2, numInputs = 1 : i32, numOutputs = 2 : i32}
|
|
||||||
}
|
|
||||||
|
|
||||||
func @inputs1results0(%arg0: memref<*xf32>) {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func @inputs1results1(%arg0: memref<*xf32>) -> memref<*xf32> {
|
|
||||||
return %arg0 : memref<*xf32>
|
|
||||||
}
|
|
||||||
|
|
||||||
func @inputs1results2(%arg0: memref<*xf32>) -> (memref<*xf32>, memref<*xf32>) {
|
|
||||||
return %arg0, %arg0 : memref<*xf32>, memref<*xf32>
|
|
||||||
}
|
|
||||||
|
|
||||||
// -----
|
|
||||||
|
|
||||||
// Test emission of compiler runtime functions.
|
|
||||||
|
|
||||||
// CHECK: llvm.mlir.global internal constant @[[STRSYM:.*]]("msg\00")
|
|
||||||
// CHECK: llvm.func @__npcomp_compiler_rt_abort_if(i1, !llvm.ptr<i8>)
|
|
||||||
|
|
||||||
// CHECK-LABEL: llvm.func @calls_abort_if(
|
|
||||||
// CHECK-SAME: %[[VAL_0:.*]]: i1) {
|
|
||||||
// CHECK: %[[VAL_0:.*]] = llvm.mlir.addressof @[[STRSYM]] : !llvm.ptr<array<4 x i8>>
|
|
||||||
// CHECK: %[[VAL_1:.*]] = llvm.mlir.constant(0 : i32) : i32
|
|
||||||
// CHECK: %[[VAL_2:.*]] = llvm.getelementptr %[[VAL_0]]{{\[}}%[[VAL_1]], %[[VAL_1]]] : (!llvm.ptr<array<4 x i8>>, i32, i32) -> !llvm.ptr<i8>
|
|
||||||
// CHECK: llvm.call @__npcomp_compiler_rt_abort_if(%[[VAL_3:.*]], %[[VAL_2]]) : (i1, !llvm.ptr<i8>) -> ()
|
|
||||||
// CHECK: llvm.return
|
|
||||||
|
|
||||||
func @calls_abort_if(%arg0: i1) {
|
|
||||||
refbackrt.abort_if %arg0, "msg"
|
|
||||||
return
|
|
||||||
}
|
|
|
@ -3,8 +3,14 @@
|
||||||
// Test module metadata.
|
// Test module metadata.
|
||||||
|
|
||||||
// CHECK: refbackrt.module_metadata
|
// CHECK: refbackrt.module_metadata
|
||||||
// CHECK-NEXT: refbackrt.func_metadata {funcName = @f_2inputs_0outputs, numInputs = 2 : i32, numOutputs = 0 : i32}
|
// CHECK-NEXT: refbackrt.func_metadata
|
||||||
// CHECK-NEXT: refbackrt.func_metadata {funcName = @f_1input_2outputs, numInputs = 1 : i32, numOutputs = 2 : i32}
|
// CHECK-SAME: funcName = @f_2inputs_0outputs
|
||||||
|
// CHECK-SAME: numInputs = 2
|
||||||
|
// CHECK-SAME: numOutputs = 0
|
||||||
|
// CHECK-NEXT: refbackrt.func_metadata
|
||||||
|
// CHECK-SAME: funcName = @f_1input_2outputs
|
||||||
|
// CHECK-SAME: numInputs = 1
|
||||||
|
// CHECK-SAME: numOutputs = 2
|
||||||
|
|
||||||
// This function only exists to test its metadata above.
|
// This function only exists to test its metadata above.
|
||||||
func @f_2inputs_0outputs(%arg0: memref<?xf32>, %arg1: memref<?xf32>) {
|
func @f_2inputs_0outputs(%arg0: memref<?xf32>, %arg1: memref<?xf32>) {
|
||||||
|
|
|
@ -9,4 +9,3 @@ func @basic(%arg0: tensor<?xf32>) -> tensor<?xf32> {
|
||||||
%0 = tcf.add %arg0, %arg0 : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
|
%0 = tcf.add %arg0, %arg0 : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
|
||||||
return %0 : tensor<?xf32>
|
return %0 : tensor<?xf32>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,26 @@
|
||||||
|
// RUN: not npcomp-run-mlir %s \
|
||||||
|
// RUN: -invoke invalid_input_shape \
|
||||||
|
// RUN: -arg-value="dense<1.0> : tensor<2x2x2x2xf32>" \
|
||||||
|
// RUN: -shared-libs=%npcomp_runtime_shlib 2>&1 \
|
||||||
|
// RUN: | FileCheck %s -check-prefix=ARG0-INVALID
|
||||||
|
|
||||||
|
// RUN: not npcomp-run-mlir %s \
|
||||||
|
// RUN: -invoke invalid_input_shape_arg1 \
|
||||||
|
// RUN: -arg-value="dense<1.0> : tensor<1x2x5xf32>" \
|
||||||
|
// RUN: -arg-value="dense<1.0> : tensor<1x2x10xf32>" \
|
||||||
|
// RUN: -shared-libs=%npcomp_runtime_shlib 2>&1 \
|
||||||
|
// RUN: | FileCheck %s -check-prefix=ARG1-INVALID
|
||||||
|
|
||||||
|
// ARG0-INVALID: invoking 'invalid_input_shape': input shape mismatch (%arg0).
|
||||||
|
// ARG0-INVALID-SAME: actual (provided by user): (2x2x2x2)
|
||||||
|
// ARG0-INVALID-SAME: expected (from compiler): (1x2x3x4)
|
||||||
|
func @invalid_input_shape(%arg0: tensor<1x2x3x4xf32>) -> tensor<1x2x3x4xf32> {
|
||||||
|
return %arg0: tensor<1x2x3x4xf32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// ARG1-INVALID: invoking 'invalid_input_shape_arg1': input shape mismatch (%arg1)
|
||||||
|
// ARG1-INVALID-SAME: actual (provided by user): (1x2x10)
|
||||||
|
// ARG1-INVALID-SAME: expected (from compiler): (1x4x?)
|
||||||
|
func @invalid_input_shape_arg1(%arg0: tensor<1x2x?xf32>, %arg1: tensor<1x4x?xf32>) {
|
||||||
|
return
|
||||||
|
}
|
|
@ -0,0 +1,13 @@
|
||||||
|
// RUN: not npcomp-run-mlir %s \
|
||||||
|
// RUN: -invoke expects_one_tensor \
|
||||||
|
// RUN: -arg-value="1.0 : f32" \
|
||||||
|
// RUN: -shared-libs=%npcomp_runtime_shlib 2>&1 \
|
||||||
|
// RUN: | FileCheck %s
|
||||||
|
|
||||||
|
// CHECK: invoking 'expects_one_tensor': input argument type mismatch.
|
||||||
|
// CHECK-SAME: actual (provided by user): Float
|
||||||
|
// CHECK-SAME: expected (from compiler): kTensor
|
||||||
|
func @expects_one_tensor(%arg0: tensor<?xf32>) -> tensor<?xf32> {
|
||||||
|
%0 = tcf.add %arg0, %arg0 : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
|
||||||
|
return %0 : tensor<?xf32>
|
||||||
|
}
|
|
@ -2,10 +2,21 @@
|
||||||
// RUN: -invoke scalar \
|
// RUN: -invoke scalar \
|
||||||
// RUN: -arg-value="dense<1.0> : tensor<f32>" \
|
// RUN: -arg-value="dense<1.0> : tensor<f32>" \
|
||||||
// RUN: -shared-libs=%npcomp_runtime_shlib 2>&1 \
|
// RUN: -shared-libs=%npcomp_runtime_shlib 2>&1 \
|
||||||
// RUN: | FileCheck %s
|
// RUN: | FileCheck %s --check-prefix=SCALAR
|
||||||
|
|
||||||
// CHECK: output #0: dense<2.000000e+00> : tensor<f32>
|
// RUN: npcomp-run-mlir %s \
|
||||||
|
// RUN: -invoke scalar_arg \
|
||||||
|
// RUN: -arg-value="2.5 : f32" \
|
||||||
|
// RUN: -shared-libs=%npcomp_runtime_shlib 2>&1 \
|
||||||
|
// RUN: | FileCheck %s --check-prefix=SCALAR_ARG
|
||||||
|
|
||||||
|
// SCALAR: output #0: dense<2.000000e+00> : tensor<f32>
|
||||||
func @scalar(%arg0: tensor<f32>) -> tensor<f32> {
|
func @scalar(%arg0: tensor<f32>) -> tensor<f32> {
|
||||||
%0 = tcf.add %arg0, %arg0 : (tensor<f32>, tensor<f32>) -> tensor<f32>
|
%0 = tcf.add %arg0, %arg0 : (tensor<f32>, tensor<f32>) -> tensor<f32>
|
||||||
return %0 : tensor<f32>
|
return %0 : tensor<f32>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SCALAR_ARG: output #0: 2.500000e+00 : f32
|
||||||
|
func @scalar_arg(%arg0: f32) -> f32 {
|
||||||
|
return %arg0 : f32
|
||||||
|
}
|
|
@ -58,6 +58,14 @@ convertAttrToTensor(Attribute attr) {
|
||||||
return make_string_error("unhandled argument");
|
return make_string_error("unhandled argument");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Extracts the value of an f32 `FloatAttr` as a host `float`.
///
/// Returns an error when `attr` does not have a float type, is not a
/// `FloatAttr`, or has a float type other than f32.
static Expected<float> convertAttrToFloat(Attribute attr) {
  auto type = attr.getType().dyn_cast<FloatType>();
  if (!type)
    return make_string_error("converting an argument to float that is not a FloatType");
  // A float-typed attribute is not necessarily a FloatAttr; dyn_cast yields a
  // null attribute in that case and must not be dereferenced.
  auto floatAttr = attr.dyn_cast<FloatAttr>();
  if (!floatAttr)
    return make_string_error("converting an argument to float that is not a FloatAttr");
  // APFloat::convertToFloat is only valid for IEEE single precision values.
  if (!type.isF32())
    return make_string_error("only f32 scalar arguments are supported");
  return floatAttr.getValue().convertToFloat();
}
|
||||||
|
|
||||||
static Expected<SmallVector<refbackrt::RtValue, 6>>
|
static Expected<SmallVector<refbackrt::RtValue, 6>>
|
||||||
createInputs(ArrayRef<StringRef> argValues) {
|
createInputs(ArrayRef<StringRef> argValues) {
|
||||||
MLIRContext context;
|
MLIRContext context;
|
||||||
|
@ -66,12 +74,22 @@ createInputs(ArrayRef<StringRef> argValues) {
|
||||||
auto attr = parseAttribute(argValue, &context);
|
auto attr = parseAttribute(argValue, &context);
|
||||||
if (!attr)
|
if (!attr)
|
||||||
return make_string_error(Twine("could not parse arg value: ") + argValue);
|
return make_string_error(Twine("could not parse arg value: ") + argValue);
|
||||||
// TODO(brycearden): Handle multiple input types
|
|
||||||
auto expectedTensor = convertAttrToTensor(attr);
|
auto attrType = attr.getType();
|
||||||
if (!expectedTensor)
|
|
||||||
return expectedTensor.takeError();
|
if (attrType.isa<RankedTensorType>()) {
|
||||||
ret.push_back(std::move(*expectedTensor));
|
auto expectedTensor = convertAttrToTensor(attr);
|
||||||
|
if (!expectedTensor)
|
||||||
|
return expectedTensor.takeError();
|
||||||
|
ret.push_back(std::move(*expectedTensor));
|
||||||
|
} else if (attrType.isa<FloatType>()) {
|
||||||
|
auto expectedFloat = convertAttrToFloat(attr);
|
||||||
|
if (!expectedFloat)
|
||||||
|
return expectedFloat.takeError();
|
||||||
|
ret.push_back(refbackrt::RtValue(*expectedFloat));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return ret;
|
return ret;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -92,34 +110,40 @@ static RankedTensorType getCorrespondingMLIRTensorType(refbackrt::Tensor &tensor
|
||||||
return RankedTensorType::get(extents, elementType);
|
return RankedTensorType::get(extents, elementType);
|
||||||
}
|
}
|
||||||
|
|
||||||
static Attribute convertToMLIRAttribute(refbackrt::Tensor &tensor,
|
static Attribute convertToMLIRAttribute(const refbackrt::RtValue &value,
|
||||||
Builder &builder) {
|
Builder &builder) {
|
||||||
RankedTensorType type = getCorrespondingMLIRTensorType(tensor, builder);
|
if (value.isTensor()) {
|
||||||
switch (tensor.getElementType()) {
|
auto& tensor = *(value.toTensor());
|
||||||
case refbackrt::ElementType::F32: {
|
RankedTensorType type = getCorrespondingMLIRTensorType(tensor, builder);
|
||||||
SmallVector<float, 100> values;
|
switch (tensor.getElementType()) {
|
||||||
auto *basePtr = tensor.getData<float>();
|
case refbackrt::ElementType::F32: {
|
||||||
for (int i = 0, e = type.getNumElements(); i < e; i++)
|
SmallVector<float, 100> values;
|
||||||
values.push_back(basePtr[i]);
|
auto *basePtr = tensor.getData<float>();
|
||||||
return DenseFPElementsAttr::get(type, values);
|
for (int i = 0, e = type.getNumElements(); i < e; i++)
|
||||||
}
|
values.push_back(basePtr[i]);
|
||||||
|
return DenseFPElementsAttr::get(type, values);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else if (value.isFloat()) {
|
||||||
|
return builder.getF32FloatAttr(value.toFloat());
|
||||||
|
} else {
|
||||||
|
assert(false && "could not convert value to mlir attribute");
|
||||||
}
|
}
|
||||||
llvm_unreachable("unsupported dtype");
|
llvm_unreachable("unsupported dtype");
|
||||||
}
|
}
|
||||||
|
|
||||||
static void printOutput(refbackrt::Tensor &tensor, llvm::raw_ostream &os) {
|
static void printOutput(const refbackrt::RtValue &value, llvm::raw_ostream &os) {
|
||||||
MLIRContext context;
|
MLIRContext context;
|
||||||
Builder builder(&context);
|
Builder builder(&context);
|
||||||
auto attr = convertToMLIRAttribute(tensor, builder);
|
auto attr = convertToMLIRAttribute(value, builder);
|
||||||
attr.print(os);
|
attr.print(os);
|
||||||
}
|
}
|
||||||
|
|
||||||
static void printOutputs(ArrayRef<refbackrt::RtValue> outputs,
|
static void printOutputs(ArrayRef<refbackrt::RtValue> outputs,
|
||||||
llvm::raw_ostream &os) {
|
llvm::raw_ostream &os) {
|
||||||
for (auto output : llvm::enumerate(outputs)) {
|
for (auto output : llvm::enumerate(outputs)) {
|
||||||
assert(output.value().isTensor() && "only tensor outputs are supported.");
|
|
||||||
os << "output #" << output.index() << ": ";
|
os << "output #" << output.index() << ": ";
|
||||||
printOutput(*output.value().toTensor().get(), os);
|
printOutput(output.value(), os);
|
||||||
os << "\n";
|
os << "\n";
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -150,9 +174,11 @@ Error compileAndRun(std::string mlirFile, mlir::MLIRContext &context,
|
||||||
auto expectedInputs = createInputs(argValues);
|
auto expectedInputs = createInputs(argValues);
|
||||||
if (!expectedInputs)
|
if (!expectedInputs)
|
||||||
return expectedInputs.takeError();
|
return expectedInputs.takeError();
|
||||||
|
|
||||||
auto expectedOutputs = jitModule->invoke(invokeFunction, *expectedInputs);
|
auto expectedOutputs = jitModule->invoke(invokeFunction, *expectedInputs);
|
||||||
if (!expectedOutputs)
|
if (!expectedOutputs)
|
||||||
return expectedOutputs.takeError();
|
return expectedOutputs.takeError();
|
||||||
|
|
||||||
auto outputs = std::move(*expectedOutputs);
|
auto outputs = std::move(*expectedOutputs);
|
||||||
printOutputs(outputs, llvm::outs());
|
printOutputs(outputs, llvm::outs());
|
||||||
llvm::outs() << "SUCCESS\n";
|
llvm::outs() << "SUCCESS\n";
|
||||||
|
|
Loading…
Reference in New Issue