2020-10-02 09:59:58 +08:00
|
|
|
//===- func_builder.cpp ---------------------------------------------------===//
|
|
|
|
//
|
|
|
|
// This file is licensed under a pytorch-style license
|
|
|
|
// See frontends/pytorch/LICENSE for license information.
|
|
|
|
//
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
#include "func_builder.h"
|
|
|
|
|
2020-12-12 06:43:38 +08:00
|
|
|
#include "mlir-c/BuiltinAttributes.h"
|
|
|
|
#include "mlir-c/BuiltinTypes.h"
|
2020-11-21 09:03:23 +08:00
|
|
|
#include "mlir-c/Diagnostics.h"
|
2020-10-02 09:59:58 +08:00
|
|
|
#include "npcomp-c/Types.h"
|
|
|
|
|
|
|
|
using namespace torch_mlir;
|
|
|
|
|
2020-10-06 14:21:21 +08:00
|
|
|
/// Builds a `std.constant` operation at `loc` producing one result of `type`,
/// with `value` attached as the op's "value" attribute.
static MlirOperation createStandardConstant(MlirLocation loc, MlirType type,
                                            MlirAttribute value) {
  OperationStateHolder stateHolder("std.constant", loc);
  // The constant op has exactly one result and one attribute.
  mlirOperationStateAddResults(stateHolder, 1, &type);
  MlirNamedAttribute namedValue =
      mlirNamedAttributeGet(toMlirStringRef("value"), value);
  mlirOperationStateAddAttributes(stateHolder, 1, &namedValue);
  return stateHolder.createOperation();
}
|
|
|
|
|
2020-11-21 09:03:23 +08:00
|
|
|
/// Begins building a `torch.kernel_call` op at `loc`, recording the target
/// kernel's name as a string attribute and mirroring the kernel's schema
/// onto the op's attributes.
KernelCallBuilder::KernelCallBuilder(MlirContext context, MlirLocation loc,
                                     const std::string &kernelName,
                                     const c10::FunctionSchema &schema)
    : context(context), loc(loc), state("torch.kernel_call", loc),
      schema(schema) {
  (void)this->context; // Preserve for future.
  // Record which kernel this call targets.
  MlirAttribute nameValue = mlirStringAttrGet(
      context, mlirStringRefCreate(kernelName.data(), kernelName.size()));
  MlirNamedAttribute namedKernelName =
      mlirNamedAttributeGet(toMlirStringRef("kernelName"), nameValue);
  mlirOperationStateAddAttributes(state, 1, &namedKernelName);
  // Mirror the schema (arg/ret types, vararg/varret/mutable flags).
  addSchemaAttrs();
}
|
|
|
|
|
|
|
|
void KernelCallBuilder::addSchemaAttrs() {
|
|
|
|
// Map the op schema to the kernel_call attributes:
|
|
|
|
// sigArgTypes
|
|
|
|
// sigRetTypes
|
|
|
|
// sigIsVararg
|
|
|
|
// sigIsVarret
|
|
|
|
// sigIsMutable
|
2020-12-15 00:42:42 +08:00
|
|
|
std::vector<MlirNamedAttribute> attrs;
|
2020-11-24 11:20:26 +08:00
|
|
|
attrs.push_back(
|
|
|
|
mlirNamedAttributeGet(toMlirStringRef("sigIsMutable"),
|
|
|
|
mlirBoolAttrGet(context, schema.is_mutable())));
|
|
|
|
attrs.push_back(
|
|
|
|
mlirNamedAttributeGet(toMlirStringRef("sigIsVararg"),
|
|
|
|
mlirBoolAttrGet(context, schema.is_vararg())));
|
|
|
|
attrs.push_back(
|
|
|
|
mlirNamedAttributeGet(toMlirStringRef("sigIsVarret"),
|
|
|
|
mlirBoolAttrGet(context, schema.is_varret())));
|
2020-11-21 09:03:23 +08:00
|
|
|
|
|
|
|
// Arg types.
|
2020-12-15 00:42:42 +08:00
|
|
|
std::vector<MlirAttribute> args;
|
2020-11-21 09:03:23 +08:00
|
|
|
for (auto &arg : schema.arguments()) {
|
|
|
|
const std::string &typeStr = arg.type()->str();
|
2020-12-12 06:43:38 +08:00
|
|
|
args.push_back(mlirStringAttrGet(
|
|
|
|
context, mlirStringRefCreate(typeStr.data(), typeStr.size())));
|
2020-11-21 09:03:23 +08:00
|
|
|
}
|
|
|
|
attrs.push_back(mlirNamedAttributeGet(
|
2020-11-24 11:20:26 +08:00
|
|
|
toMlirStringRef("sigArgTypes"),
|
|
|
|
mlirArrayAttrGet(context, args.size(), args.data())));
|
2020-11-21 09:03:23 +08:00
|
|
|
|
|
|
|
// Return types.
|
2020-12-15 00:42:42 +08:00
|
|
|
std::vector<MlirAttribute> returns;
|
2020-11-21 09:03:23 +08:00
|
|
|
for (auto &ret : schema.returns()) {
|
|
|
|
const std::string &typeStr = ret.type()->str();
|
2020-12-12 06:43:38 +08:00
|
|
|
returns.push_back(mlirStringAttrGet(
|
|
|
|
context, mlirStringRefCreate(typeStr.data(), typeStr.size())));
|
2020-11-21 09:03:23 +08:00
|
|
|
}
|
|
|
|
attrs.push_back(mlirNamedAttributeGet(
|
2020-11-25 05:10:27 +08:00
|
|
|
toMlirStringRef("sigRetTypes"),
|
2020-11-21 09:03:23 +08:00
|
|
|
mlirArrayAttrGet(context, returns.size(), returns.data())));
|
|
|
|
|
|
|
|
// Add attrs to op.
|
|
|
|
mlirOperationStateAddAttributes(state, attrs.size(), attrs.data());
|
|
|
|
}
|
|
|
|
|
|
|
|
/// Appends a single operand to the kernel call being built.
void KernelCallBuilder::addOperand(MlirValue operand) {
  mlirOperationStateAddOperands(state, 1, &operand);
}
|
|
|
|
|
|
|
|
/// Appends a single result type to the kernel call being built.
void KernelCallBuilder::addResultType(MlirType resultType) {
  mlirOperationStateAddResults(state, 1, &resultType);
}
|
|
|
|
|
|
|
|
MlirOperation KernelCallBuilder::create() { return state.createOperation(); }
|
|
|
|
|
|
|
|
/// Maps a Torch scalar type to the corresponding MLIR type, throwing
/// std::invalid_argument when no mapping exists.
MlirType TypeMapper::mapFromTorchScalarType(c10::ScalarType scalarType) {
  MlirType mapped = rawMapFromTorchScalarType(scalarType);
  if (!mlirTypeIsNull(mapped))
    return mapped;
  std::stringstream message;
  message << "unsupported PyTorch scalar type: " << c10::toString(scalarType);
  throw std::invalid_argument(message.str());
}
|
|
|
|
|
|
|
|
/// Maps a Torch scalar type to the corresponding MLIR type. On failure,
/// emits a diagnostic at `loc` and returns the null type.
MlirType TypeMapper::mapFromTorchScalarType(MlirLocation loc,
                                            c10::ScalarType scalarType) {
  MlirType mapped = rawMapFromTorchScalarType(scalarType);
  if (!mlirTypeIsNull(mapped))
    return mapped;
  std::stringstream message;
  message << "unsupported PyTorch scalar type: " << c10::toString(scalarType);
  mlirEmitError(loc, message.str().c_str());
  return mapped;
}
|
|
|
|
|
|
|
|
/// Raw mapping of a Torch scalar type to an MLIR type. Returns the null
/// type (without emitting diagnostics) when the scalar type is unsupported.
MlirType TypeMapper::rawMapFromTorchScalarType(c10::ScalarType scalarType) {
  using c10::ScalarType;
  switch (scalarType) {
  // Integer types. All are currently mapped to signless integers.
  // TODO: convert to mlirIntegerTypeUnsignedGet (Byte) and
  // mlirIntegerTypeSignedGet (Short/Int/Long) once supported.
  case ScalarType::Byte:
  case ScalarType::Char:
    return mlirIntegerTypeGet(context, 8);
  case ScalarType::Short:
    return mlirIntegerTypeGet(context, 16);
  case ScalarType::Int:
    return mlirIntegerTypeGet(context, 32);
  case ScalarType::Long:
    return mlirIntegerTypeGet(context, 64);
  // Bool maps to the npcomp bool type rather than an MLIR integer.
  case ScalarType::Bool:
    return npcompBoolTypeGet(context);
  // Floating point types.
  case ScalarType::Double:
    return mlirF64TypeGet(context);
  case ScalarType::Float:
    return mlirF32TypeGet(context);
  case ScalarType::BFloat16:
    return mlirBF16TypeGet(context);
  case ScalarType::Half:
    return mlirF16TypeGet(context);
  default:
    return {nullptr};
  }
}
|
|
|
|
|
|
|
|
/// Maps a Torch type to the corresponding MLIR type. On failure, emits a
/// diagnostic at `loc` and returns the null type. Currently only tensor
/// types are handled; they map to npcomp NdArray types (unranked when the
/// rank is unknown, ranked with -1 for dynamic dims otherwise).
MlirType TypeMapper::mapFromTorchType(MlirLocation loc,
                                      const c10::TypePtr &torchType) {
  using c10::TypeKind;
  auto kind = torchType->kind();
  switch (kind) {
  case TypeKind::TensorType: {
    auto tensorType = torchType->cast<c10::TensorType>();
    // Element type: map the dtype if known, otherwise defer to the compiler
    // via the "any dtype" placeholder type.
    MlirType elementType;
    if (tensorType->scalarType()) {
      elementType = mapFromTorchScalarType(loc, *tensorType->scalarType());
      if (mlirTypeIsNull(elementType))
        return {nullptr};
    } else {
      elementType = npcompAnyDtypeTypeGet(context);
    }
    // Sizes. (Previously symbolic_sizes() was bound twice under two names;
    // one binding suffices.)
    auto &sizes = tensorType->symbolic_sizes();
    if (!sizes.rank()) {
      // Unranked.
      return npcompNdArrayTypeGetUnranked(elementType);
    }
    // Ranked with possibly dynamic dims; dynamic dims are encoded as -1.
    std::vector<int64_t> dims;
    dims.resize(*sizes.rank());
    for (size_t i = 0; i < dims.size(); ++i) {
      auto shapeSymbol = sizes[i];
      dims[i] = shapeSymbol.is_static() ? shapeSymbol.static_size() : -1;
    }
    return npcompNdArrayTypeGetRanked(dims.size(), dims.data(), elementType);
  }
  default: {
    std::stringstream message;
    // NOTE(review): this streams the TypePtr itself; confirm c10 provides an
    // operator<< that prints the type (and not a pointer value).
    message << "unable to map Torch type " << torchType << " to MLIR type";
    mlirEmitError(loc, message.str().c_str());
    return {nullptr};
  }
  }
}
|
|
|
|
|
|
|
|
/// Derives an MLIR type for `tensor` as seen across the forward() boundary.
MlirType TypeMapper::forwardTensorToType(at::Tensor tensor) {
  // Undefined tensors are equivalent to None.
  // This may need to be re-evaluated at some point.
  if (!tensor.defined())
    return npcompNoneTypeGet(context);

  MlirType elementType = mapFromTorchScalarType(tensor.scalar_type());
  // TODO: Decide when it is necessary to take strides into account. Right now,
  // just erase them and let the compiler decide.
  auto shape = tensor.sizes();
  return npcompNdArrayTypeGetRanked(shape.size(), shape.data(), elementType);
}
|
|
|
|
|
|
|
|
/// Creates an empty `func` op named `name` with the given input types and no
/// results, hands it to `inserter` for placement, and returns a FuncBuilder
/// positioned at the function's entry block.
std::unique_ptr<FuncBuilder>
FuncBuilder::createFunction(FuncBuilder::Inserter &inserter,
                            MlirLocation location, const std::string &name,
                            std::vector<MlirType> &inputTypes) {
  auto context = mlirLocationGetContext(location);
  // TODO: Create a dedicated API upstream for creating/manipulating func ops.
  // (this is fragile and reveals details that are not guaranteed).
  std::vector<MlirNamedAttribute> funcAttrs;
  // "type": function signature, inputs only (results rewritten later via
  // rewriteFuncReturnTypes).
  funcAttrs.push_back(
      mlirNamedAttributeGet(toMlirStringRef("type"),
                            mlirTypeAttrGet(mlirFunctionTypeGet(
                                context, inputTypes.size(), inputTypes.data(),
                                /*numResults=*/0, /*results=*/nullptr))));
  // "sym_name": the function's symbol name.
  funcAttrs.push_back(mlirNamedAttributeGet(
      toMlirStringRef("sym_name"),
      mlirStringAttrGet(context,
                        mlirStringRefCreate(name.data(), name.size()))));

  MlirOperationState state =
      mlirOperationStateGet(toMlirStringRef("func"), location);
  mlirOperationStateAddAttributes(&state, funcAttrs.size(), funcAttrs.data());
  {
    // Don't access these once ownership transferred.
    MlirRegion newBodyRegion = mlirRegionCreate();
    MlirBlock newEntryBlock =
        mlirBlockCreate(inputTypes.size(), inputTypes.data());
    mlirRegionInsertOwnedBlockAfter(newBodyRegion, {nullptr}, newEntryBlock);
    mlirOperationStateAddOwnedRegions(&state, 1, &newBodyRegion);
  }

  // Need to re-lookup the region/block because we relinquished ownership above.
  MlirOperation funcOp = mlirOperationCreate(&state);
  MlirRegion bodyRegion = mlirOperationGetRegion(funcOp, 0);
  MlirBlock entryBlock = mlirRegionGetFirstBlock(bodyRegion);

  // Let the caller decide where the func op lives (e.g. module body).
  inserter(funcOp);
  return std::unique_ptr<FuncBuilder>(new FuncBuilder(
      context, funcOp, BlockBuilder(entryBlock, /*returnOp=*/{nullptr}, true)));
}
|
|
|
|
|
2020-12-15 00:42:42 +08:00
|
|
|
/// Rewrites the function's signature to have the given result types while
/// preserving its existing input types.
void FuncBuilder::rewriteFuncReturnTypes(std::vector<MlirType> &resultTypes) {
  // Get inputs from current function type.
  MlirAttribute funcTypeAttr =
      mlirOperationGetAttributeByName(funcOp, toMlirStringRef("type"));
  assert(!mlirAttributeIsNull(funcTypeAttr) &&
         "function missing 'type' attribute");
  assert(mlirAttributeIsAType(funcTypeAttr) &&
         "function type is not a TypeAttr");
  MlirType funcType = mlirTypeAttrGetValue(funcTypeAttr);
  std::vector<MlirType> inputTypes;
  for (intptr_t i = 0, e = mlirFunctionTypeGetNumInputs(funcType); i < e; ++i) {
    inputTypes.push_back(mlirFunctionTypeGetInput(funcType, i));
  }

  // Make new function type with the same inputs and the new results.
  MlirType newFuncType =
      mlirFunctionTypeGet(context, inputTypes.size(), inputTypes.data(),
                          resultTypes.size(), resultTypes.data());
  MlirAttribute newFuncTypeAttr = mlirTypeAttrGet(newFuncType);
  mlirOperationSetAttributeByName(funcOp, toMlirStringRef("type"),
                                  newFuncTypeAttr);
  // Note: a leftover `(void)newFuncTypeAttr;` unused-variable suppression was
  // removed — the variable is in fact used above.
}
|
|
|
|
|
|
|
|
/// Inserts `op` into the function's constant section, keeping constants
/// grouped together by chaining each new one after the previous. Returns the
/// op's first result.
MlirValue FuncBuilder::insertConstantOp(MlirOperation op) {
  // NOTE(review): when prevConstantOp is the null operation this relies on
  // mlirBlockInsertOwnedOperationAfter inserting at the block start — confirm
  // against the MLIR C API docs.
  mlirBlockInsertOwnedOperationAfter(entryBlock.getBlock(), prevConstantOp, op);
  prevConstantOp = op;
  return mlirOperationGetResult(op, 0);
}
|
|
|
|
|
|
|
|
/// Finds the MLIR value previously mapped to `tensor` (compared by identity
/// via is_same), searching the most recent mappings first. Returns the null
/// value when no mapping exists.
MlirValue FuncBuilder::lookupTensor(at::Tensor tensor) {
  auto iter = tensorValueMap.rbegin();
  auto end = tensorValueMap.rend();
  for (; iter != end; ++iter) {
    if (iter->first.is_same(tensor))
      return iter->second;
  }
  return {nullptr};
}
|
2020-10-06 14:21:21 +08:00
|
|
|
|
|
|
|
/// Materializes an interpreter scalar as a constant in the function.
///
/// Note that interpreter "scalars" match the Python semantics and are
/// represented as one of double or int64_t, with a special tag for whether
/// it should be interpreted as a bool. Throws std::invalid_argument for
/// unsupported scalar kinds (e.g. complex).
MlirValue FuncBuilder::getScalarConstant(MlirLocation loc, at::Scalar s) {
  if (s.isIntegral(/*includeBool=*/false)) {
    // TODO: Switch to a basicpy.constant that works properly with signed
    // integers and then switch this to a signed integer.
    MlirType i64Type = mlirIntegerTypeGet(context, 64);
    return getGeneralConstant(loc,
                              mlirIntegerAttrGet(i64Type, s.to<int64_t>()));
  }
  if (s.isFloatingPoint()) {
    MlirType f64Type = mlirF64TypeGet(context);
    return getGeneralConstant(
        loc, mlirFloatAttrDoubleGet(context, f64Type, s.to<double>()));
  }
  if (s.isBoolean())
    return getBoolConstant(loc, s.to<bool>());
  // TODO: s.isComplex()
  throw std::invalid_argument("TODO: Scalar of unknown kind");
}
|
2020-10-16 09:28:30 +08:00
|
|
|
|
|
|
|
/// Materializes a bool constant in the function and returns its value.
MlirValue FuncBuilder::getBoolConstant(MlirLocation loc, bool v) {
  return getGeneralConstant(loc, mlirBoolAttrGet(context, v));
}
|
|
|
|
|
2020-10-17 08:38:07 +08:00
|
|
|
/// Materializes the None singleton via `basicpy.singleton`, inserting it into
/// the function's constant section, and returns its value.
MlirValue FuncBuilder::getNoneConstant(MlirLocation loc) {
  OperationStateHolder stateHolder{"basicpy.singleton", loc};
  MlirType resultType = npcompNoneTypeGet(context);
  mlirOperationStateAddResults(stateHolder, 1, &resultType);
  return insertConstantOp(stateHolder.createOperation());
}
|
|
|
|
|
2020-10-16 09:28:30 +08:00
|
|
|
/// Emits a std.constant holding `value` (the result type is taken from the
/// attribute itself) into the constant section and returns its value.
MlirValue FuncBuilder::getGeneralConstant(MlirLocation loc,
                                          MlirAttribute value) {
  MlirType attrType = mlirAttributeGetType(value);
  return insertConstantOp(createStandardConstant(loc, attrType, value));
}
|
2020-10-17 08:38:07 +08:00
|
|
|
|
Add a number of kernels and new patterns.
* convolution, convolution_backward, _log_softmax, _log_softmax_backward_data, nll_loss_forward, nll_loss_backward, nll_loss2d_forward, nll_loss2d_backward, copy_
* Extends the recognition logic and metadata for handling inplace transformations, optional tensors, ints, lists and dropped args.
* The kernel_calls generated by test_conv_nllloss_grads.py now convert to ATen.
* The result *almost* comes out as a pure tensor program with the exception of the copy_ op, which I will do some followup work to deal with.
* More progress on #97
2020-11-04 11:24:28 +08:00
|
|
|
/// Builds a `basicpy.build_list` op over `elements` and returns the resulting
/// list value. Inserted before the block terminator (not into the constant
/// section) since its operands may be arbitrary values.
MlirValue FuncBuilder::buildList(MlirLocation loc,
                                 std::vector<MlirValue> &elements) {
  OperationStateHolder stateHolder{"basicpy.build_list", loc};
  MlirType listType = npcompListTypeGet(context);
  mlirOperationStateAddResults(stateHolder, 1, &listType);
  mlirOperationStateAddOperands(stateHolder, elements.size(), elements.data());
  MlirOperation listOp = stateHolder.createOperation();
  entryBlock.insertBeforeTerminator(listOp);
  return mlirOperationGetResult(listOp, 0);
}
|