Move Torch type implementation code into TorchTypes.cpp

pull/219/head
Sean Silva 2021-06-10 16:46:20 -07:00
parent 0b6516c7cc
commit 81bcd7fb12
4 changed files with 281 additions and 256 deletions

View File

@ -9,7 +9,7 @@
#ifndef NPCOMP_DIALECT_TORCH_IR_TORCHTYPES_H #ifndef NPCOMP_DIALECT_TORCH_IR_TORCHTYPES_H
#define NPCOMP_DIALECT_TORCH_IR_TORCHTYPES_H #define NPCOMP_DIALECT_TORCH_IR_TORCHTYPES_H
#include "mlir/IR/Types.h" #include "mlir/IR/BuiltinTypes.h"
namespace mlir { namespace mlir {
namespace NPCOMP { namespace NPCOMP {

View File

@ -1,6 +1,7 @@
add_npcomp_dialect_library(NPCOMPTorchDialect add_npcomp_dialect_library(NPCOMPTorchDialect
TorchDialect.cpp TorchDialect.cpp
TorchOps.cpp TorchOps.cpp
TorchTypes.cpp
TorchUtils.cpp TorchUtils.cpp
ADDITIONAL_HEADER_DIRS ADDITIONAL_HEADER_DIRS

View File

@ -47,6 +47,10 @@ struct TorchInlinerInterface : public DialectInlinerInterface {
#define GET_TYPEDEF_CLASSES #define GET_TYPEDEF_CLASSES
#include "npcomp/Dialect/Torch/IR/TorchTypes.cpp.inc" #include "npcomp/Dialect/Torch/IR/TorchTypes.cpp.inc"
//===----------------------------------------------------------------------===//
// Dialect initialize method.
//===----------------------------------------------------------------------===//
void TorchDialect::initialize() { void TorchDialect::initialize() {
addOperations< addOperations<
#define GET_OP_LIST #define GET_OP_LIST
@ -61,6 +65,10 @@ void TorchDialect::initialize() {
getContext()->loadDialect<Basicpy::BasicpyDialect>(); getContext()->loadDialect<Basicpy::BasicpyDialect>();
} }
//===----------------------------------------------------------------------===//
// Type-related Dialect methods.
//===----------------------------------------------------------------------===//
Type TorchDialect::parseType(DialectAsmParser &parser) const { Type TorchDialect::parseType(DialectAsmParser &parser) const {
StringRef keyword; StringRef keyword;
if (parser.parseKeyword(&keyword)) if (parser.parseKeyword(&keyword))
@ -79,261 +87,6 @@ void TorchDialect::printType(Type type, DialectAsmPrinter &printer) const {
llvm_unreachable("unknown 'torch' type"); llvm_unreachable("unknown 'torch' type");
} }
//===----------------------------------------------------------------------===//
// BaseTensorType
//===----------------------------------------------------------------------===//
// TODO: Move most of this to a new file TorchTypes.cpp.
static bool isValidTorchDtype(Type dtype) {
// Torch quantized types.
if (dtype.isa<Torch::QInt8Type>())
return true;
// Builtin floating point types.
if (dtype.isa<Float16Type, BFloat16Type, Float32Type, Float64Type>())
return true;
// Builtin integer types.
if (IntegerType type = dtype.dyn_cast<IntegerType>()) {
if (type.isSignless() && type.getWidth() == 1)
return true;
if (type.isSigned()) {
for (unsigned width : {8, 16, 32, 64}) {
if (type.getWidth() == width)
return true;
}
}
if (type.isUnsigned()) {
return type.getWidth() == 8;
}
}
return false;
}
bool BaseTensorType::hasSameSizesAndDtype(BaseTensorType other) const {
return getOptionalSizes() == other.getOptionalSizes() &&
getOptionalDtype() == other.getOptionalDtype();
}
Type BaseTensorType::getWithSizesAndDtypeFrom(BaseTensorType other) const {
return getWithSizesAndDtype(other.getOptionalSizes(),
other.getOptionalDtype());
}
Type BaseTensorType::getWithSizesAndDtype(
Optional<ArrayRef<int64_t>> optionalSizes, Type optionalDtype) const {
if (isa<NonValueTensorType>())
return NonValueTensorType::get(getContext(), optionalSizes, optionalDtype);
if (isa<ValueTensorType>())
return ValueTensorType::get(getContext(), optionalSizes, optionalDtype);
llvm_unreachable("not a BaseTensorType!");
}
ValueTensorType BaseTensorType::getWithValueSemantics() const {
if (auto tensor = dyn_cast<NonValueTensorType>())
return tensor.getWithValueSemantics();
if (auto tensor = dyn_cast<ValueTensorType>())
return tensor;
llvm_unreachable("not a BaseTensorType!");
}
static LogicalResult
verifyTensorType(function_ref<InFlightDiagnostic()> emitError,
Optional<ArrayRef<int64_t>> optionalSizes,
Type optionalDtype) {
if (optionalDtype && !isValidTorchDtype(optionalDtype)) {
emitError() << "invalid dtype " << optionalDtype
<< " for !torch.tensor type";
return failure();
}
return success();
}
Type parseTensorType(MLIRContext *context, DialectAsmParser &parser,
GetTensorTypeFn getTensorType) {
llvm::SMLoc startLoc = parser.getCurrentLocation();
if (parser.parseOptionalLess())
return getTensorType(context,
/*optionalSizes=*/None, /*optionalDtype=*/Type());
bool hasSizes;
SmallVector<int64_t> sizes;
if (succeeded(parser.parseOptionalStar())) {
// Unranked.
hasSizes = false;
} else {
// Parse list of sizes.
hasSizes = true;
if (parser.parseLSquare())
return Type();
for (bool first = true;; first = false) {
if (!first) {
if (failed(parser.parseOptionalComma())) {
break;
}
}
if (succeeded(parser.parseOptionalQuestion())) {
sizes.push_back(-1);
continue;
}
int64_t size;
auto optionalInt = parser.parseOptionalInteger(size);
if (optionalInt.hasValue()) {
if (failed(*optionalInt))
return Type();
sizes.push_back(size);
continue;
}
break;
}
if (parser.parseRSquare()) {
return Type();
}
}
if (parser.parseComma())
return Type();
Type optionalDtype;
if (succeeded(parser.parseOptionalKeyword("unk"))) {
// Unknown dtype.
} else {
// Known dtype.
if (parser.parseType(optionalDtype))
return Type();
}
if (parser.parseGreater())
return Type();
Optional<ArrayRef<int64_t>> optionalSizes;
if (hasSizes)
optionalSizes.emplace(sizes);
if (failed(verifyTensorType([&]() { return parser.emitError(startLoc); },
optionalSizes, optionalDtype)))
return Type();
return getTensorType(context, optionalSizes, optionalDtype);
}
static void printTensorType(DialectAsmPrinter &printer,
Optional<ArrayRef<int64_t>> optionalSizes,
Type optionalDtype) {
if (!optionalSizes && !optionalDtype)
return;
printer << "<";
if (optionalSizes) {
printer << "[";
for (auto it : llvm::enumerate(*optionalSizes)) {
if (it.index() > 0)
printer << ",";
if (it.value() < 0)
printer << "?";
else
printer << it.value();
}
printer << "]";
} else {
printer << "*";
}
printer << ",";
if (optionalDtype)
printer.printType(optionalDtype);
else
printer << "unk";
printer << ">";
}
//===----------------------------------------------------------------------===//
// NonValueTensorType
//===----------------------------------------------------------------------===//
ValueTensorType NonValueTensorType::getWithValueSemantics() const {
return ValueTensorType::get(getContext(), getOptionalSizes(),
getOptionalDtype());
}
NonValueTensorType
NonValueTensorType::getWithLeastStaticInformation(MLIRContext *context) {
return NonValueTensorType::get(context,
/*optionalSizes=*/None,
/*optionalDtype=*/Type());
}
NonValueTensorType NonValueTensorType::getFromShaped(ShapedType type) {
return NonValueTensorType::get(type.getContext(),
type.hasRank() ? type.getShape()
: Optional<ArrayRef<int64_t>>(),
type.getElementType());
}
LogicalResult
NonValueTensorType::verify(function_ref<InFlightDiagnostic()> emitError,
Optional<ArrayRef<int64_t>> optionalSizes,
Type optionalDtype) {
return verifyTensorType(emitError, optionalSizes, optionalDtype);
}
Type NonValueTensorType::parse(MLIRContext *context, DialectAsmParser &parser) {
return parseTensorType(
context, parser,
[](MLIRContext *context, Optional<ArrayRef<int64_t>> optionalSizes,
Type optionalType) {
return NonValueTensorType::get(context, optionalSizes, optionalType);
});
}
void NonValueTensorType::print(DialectAsmPrinter &printer) const {
printer << "tensor";
printTensorType(printer, getOptionalSizes(), getOptionalDtype());
}
//===----------------------------------------------------------------------===//
// ValueTensorType
//===----------------------------------------------------------------------===//
NonValueTensorType ValueTensorType::getWithoutValueSemantics() const {
return NonValueTensorType::get(getContext(), getOptionalSizes(),
getOptionalDtype());
}
ValueTensorType
ValueTensorType::getWithLeastStaticInformation(MLIRContext *context) {
return ValueTensorType::get(context,
/*optionalSizes=*/None,
/*optionalDtype=*/Type());
}
ValueTensorType ValueTensorType::getFromShaped(ShapedType type) {
return ValueTensorType::get(type.getContext(),
type.hasRank() ? type.getShape()
: Optional<ArrayRef<int64_t>>(),
type.getElementType());
}
TensorType ValueTensorType::toBuiltinTensor() const {
if (!hasDtype())
return nullptr;
if (!hasSizes())
return UnrankedTensorType::get(getDtype());
return RankedTensorType::get(getSizes(), getDtype());
}
LogicalResult
ValueTensorType::verify(function_ref<InFlightDiagnostic()> emitError,
Optional<ArrayRef<int64_t>> optionalSizes,
Type optionalDtype) {
return verifyTensorType(emitError, optionalSizes, optionalDtype);
}
Type ValueTensorType::parse(MLIRContext *context, DialectAsmParser &parser) {
return parseTensorType(
context, parser,
[](MLIRContext *context, Optional<ArrayRef<int64_t>> optionalSizes,
Type optionalType) {
return ValueTensorType::get(context, optionalSizes, optionalType);
});
}
void ValueTensorType::print(DialectAsmPrinter &printer) const {
printer << "vtensor";
printTensorType(printer, getOptionalSizes(), getOptionalDtype());
}
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// Dialect-level verifiers. // Dialect-level verifiers.

View File

@ -0,0 +1,271 @@
//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "npcomp/Dialect/Torch/IR/TorchTypes.h"
#include "mlir/IR/DialectImplementation.h"
#include "npcomp/Dialect/Torch/IR/TorchDialect.h"
#include "npcomp/Dialect/Torch/IR/TorchOps.h"
using namespace mlir;
using namespace mlir::NPCOMP;
using namespace mlir::NPCOMP::Torch;
//===----------------------------------------------------------------------===//
// BaseTensorType
//===----------------------------------------------------------------------===//
static bool isValidTorchDtype(Type dtype) {
// Torch quantized types.
if (dtype.isa<Torch::QInt8Type>())
return true;
// Builtin floating point types.
if (dtype.isa<Float16Type, BFloat16Type, Float32Type, Float64Type>())
return true;
// Builtin integer types.
if (IntegerType type = dtype.dyn_cast<IntegerType>()) {
if (type.isSignless() && type.getWidth() == 1)
return true;
if (type.isSigned()) {
for (unsigned width : {8, 16, 32, 64}) {
if (type.getWidth() == width)
return true;
}
}
if (type.isUnsigned()) {
return type.getWidth() == 8;
}
}
return false;
}
bool BaseTensorType::hasSameSizesAndDtype(BaseTensorType other) const {
return getOptionalSizes() == other.getOptionalSizes() &&
getOptionalDtype() == other.getOptionalDtype();
}
Type BaseTensorType::getWithSizesAndDtypeFrom(BaseTensorType other) const {
return getWithSizesAndDtype(other.getOptionalSizes(),
other.getOptionalDtype());
}
Type BaseTensorType::getWithSizesAndDtype(
Optional<ArrayRef<int64_t>> optionalSizes, Type optionalDtype) const {
if (isa<NonValueTensorType>())
return NonValueTensorType::get(getContext(), optionalSizes, optionalDtype);
if (isa<ValueTensorType>())
return ValueTensorType::get(getContext(), optionalSizes, optionalDtype);
llvm_unreachable("not a BaseTensorType!");
}
ValueTensorType BaseTensorType::getWithValueSemantics() const {
if (auto tensor = dyn_cast<NonValueTensorType>())
return tensor.getWithValueSemantics();
if (auto tensor = dyn_cast<ValueTensorType>())
return tensor;
llvm_unreachable("not a BaseTensorType!");
}
static LogicalResult
verifyTensorType(function_ref<InFlightDiagnostic()> emitError,
Optional<ArrayRef<int64_t>> optionalSizes,
Type optionalDtype) {
if (optionalDtype && !isValidTorchDtype(optionalDtype)) {
emitError() << "invalid dtype " << optionalDtype
<< " for !torch.tensor type";
return failure();
}
return success();
}
Type parseTensorType(MLIRContext *context, DialectAsmParser &parser,
GetTensorTypeFn getTensorType) {
llvm::SMLoc startLoc = parser.getCurrentLocation();
if (parser.parseOptionalLess())
return getTensorType(context,
/*optionalSizes=*/None, /*optionalDtype=*/Type());
bool hasSizes;
SmallVector<int64_t> sizes;
if (succeeded(parser.parseOptionalStar())) {
// Unranked.
hasSizes = false;
} else {
// Parse list of sizes.
hasSizes = true;
if (parser.parseLSquare())
return Type();
for (bool first = true;; first = false) {
if (!first) {
if (failed(parser.parseOptionalComma())) {
break;
}
}
if (succeeded(parser.parseOptionalQuestion())) {
sizes.push_back(-1);
continue;
}
int64_t size;
auto optionalInt = parser.parseOptionalInteger(size);
if (optionalInt.hasValue()) {
if (failed(*optionalInt))
return Type();
sizes.push_back(size);
continue;
}
break;
}
if (parser.parseRSquare()) {
return Type();
}
}
if (parser.parseComma())
return Type();
Type optionalDtype;
if (succeeded(parser.parseOptionalKeyword("unk"))) {
// Unknown dtype.
} else {
// Known dtype.
if (parser.parseType(optionalDtype))
return Type();
}
if (parser.parseGreater())
return Type();
Optional<ArrayRef<int64_t>> optionalSizes;
if (hasSizes)
optionalSizes.emplace(sizes);
if (failed(verifyTensorType([&]() { return parser.emitError(startLoc); },
optionalSizes, optionalDtype)))
return Type();
return getTensorType(context, optionalSizes, optionalDtype);
}
static void printTensorType(DialectAsmPrinter &printer,
Optional<ArrayRef<int64_t>> optionalSizes,
Type optionalDtype) {
if (!optionalSizes && !optionalDtype)
return;
printer << "<";
if (optionalSizes) {
printer << "[";
for (auto it : llvm::enumerate(*optionalSizes)) {
if (it.index() > 0)
printer << ",";
if (it.value() < 0)
printer << "?";
else
printer << it.value();
}
printer << "]";
} else {
printer << "*";
}
printer << ",";
if (optionalDtype)
printer.printType(optionalDtype);
else
printer << "unk";
printer << ">";
}
//===----------------------------------------------------------------------===//
// NonValueTensorType
//===----------------------------------------------------------------------===//
ValueTensorType NonValueTensorType::getWithValueSemantics() const {
return ValueTensorType::get(getContext(), getOptionalSizes(),
getOptionalDtype());
}
NonValueTensorType
NonValueTensorType::getWithLeastStaticInformation(MLIRContext *context) {
return NonValueTensorType::get(context,
/*optionalSizes=*/None,
/*optionalDtype=*/Type());
}
NonValueTensorType NonValueTensorType::getFromShaped(ShapedType type) {
return NonValueTensorType::get(type.getContext(),
type.hasRank() ? type.getShape()
: Optional<ArrayRef<int64_t>>(),
type.getElementType());
}
LogicalResult
NonValueTensorType::verify(function_ref<InFlightDiagnostic()> emitError,
Optional<ArrayRef<int64_t>> optionalSizes,
Type optionalDtype) {
return verifyTensorType(emitError, optionalSizes, optionalDtype);
}
Type NonValueTensorType::parse(MLIRContext *context, DialectAsmParser &parser) {
return parseTensorType(
context, parser,
[](MLIRContext *context, Optional<ArrayRef<int64_t>> optionalSizes,
Type optionalType) {
return NonValueTensorType::get(context, optionalSizes, optionalType);
});
}
void NonValueTensorType::print(DialectAsmPrinter &printer) const {
printer << "tensor";
printTensorType(printer, getOptionalSizes(), getOptionalDtype());
}
//===----------------------------------------------------------------------===//
// ValueTensorType
//===----------------------------------------------------------------------===//
NonValueTensorType ValueTensorType::getWithoutValueSemantics() const {
return NonValueTensorType::get(getContext(), getOptionalSizes(),
getOptionalDtype());
}
ValueTensorType
ValueTensorType::getWithLeastStaticInformation(MLIRContext *context) {
return ValueTensorType::get(context,
/*optionalSizes=*/None,
/*optionalDtype=*/Type());
}
ValueTensorType ValueTensorType::getFromShaped(ShapedType type) {
return ValueTensorType::get(type.getContext(),
type.hasRank() ? type.getShape()
: Optional<ArrayRef<int64_t>>(),
type.getElementType());
}
TensorType ValueTensorType::toBuiltinTensor() const {
if (!hasDtype())
return nullptr;
if (!hasSizes())
return UnrankedTensorType::get(getDtype());
return RankedTensorType::get(getSizes(), getDtype());
}
LogicalResult
ValueTensorType::verify(function_ref<InFlightDiagnostic()> emitError,
Optional<ArrayRef<int64_t>> optionalSizes,
Type optionalDtype) {
return verifyTensorType(emitError, optionalSizes, optionalDtype);
}
Type ValueTensorType::parse(MLIRContext *context, DialectAsmParser &parser) {
return parseTensorType(
context, parser,
[](MLIRContext *context, Optional<ArrayRef<int64_t>> optionalSizes,
Type optionalType) {
return ValueTensorType::get(context, optionalSizes, optionalType);
});
}
void ValueTensorType::print(DialectAsmPrinter &printer) const {
printer << "vtensor";
printTensorType(printer, getOptionalSizes(), getOptionalDtype());
}