diff --git a/include/npcomp/Dialect/Torch/IR/TorchTypes.h b/include/npcomp/Dialect/Torch/IR/TorchTypes.h index 55a18c351..77a17bbf6 100644 --- a/include/npcomp/Dialect/Torch/IR/TorchTypes.h +++ b/include/npcomp/Dialect/Torch/IR/TorchTypes.h @@ -9,7 +9,7 @@ #ifndef NPCOMP_DIALECT_TORCH_IR_TORCHTYPES_H #define NPCOMP_DIALECT_TORCH_IR_TORCHTYPES_H -#include "mlir/IR/Types.h" +#include "mlir/IR/BuiltinTypes.h" namespace mlir { namespace NPCOMP { diff --git a/lib/Dialect/Torch/IR/CMakeLists.txt b/lib/Dialect/Torch/IR/CMakeLists.txt index 5da554f3f..974771915 100644 --- a/lib/Dialect/Torch/IR/CMakeLists.txt +++ b/lib/Dialect/Torch/IR/CMakeLists.txt @@ -1,6 +1,7 @@ add_npcomp_dialect_library(NPCOMPTorchDialect TorchDialect.cpp TorchOps.cpp + TorchTypes.cpp TorchUtils.cpp ADDITIONAL_HEADER_DIRS diff --git a/lib/Dialect/Torch/IR/TorchDialect.cpp b/lib/Dialect/Torch/IR/TorchDialect.cpp index 993bda6dd..709039af7 100644 --- a/lib/Dialect/Torch/IR/TorchDialect.cpp +++ b/lib/Dialect/Torch/IR/TorchDialect.cpp @@ -47,6 +47,10 @@ struct TorchInlinerInterface : public DialectInlinerInterface { #define GET_TYPEDEF_CLASSES #include "npcomp/Dialect/Torch/IR/TorchTypes.cpp.inc" +//===----------------------------------------------------------------------===// +// Dialect initialize method. +//===----------------------------------------------------------------------===// + void TorchDialect::initialize() { addOperations< #define GET_OP_LIST @@ -61,6 +65,10 @@ void TorchDialect::initialize() { getContext()->loadDialect(); } +//===----------------------------------------------------------------------===// +// Type-related Dialect methods. +//===----------------------------------------------------------------------===// + Type TorchDialect::parseType(DialectAsmParser &parser) const { StringRef keyword; if (parser.parseKeyword(&keyword)) @@ -79,261 +87,6 @@ void TorchDialect::printType(Type type, DialectAsmPrinter &printer) const { llvm_unreachable("unknown 'torch' type"); } -//===----------------------------------------------------------------------===// -// BaseTensorType -//===----------------------------------------------------------------------===// -// TODO: Move most of this to a new file TorchTypes.cpp. - -static bool isValidTorchDtype(Type dtype) { - // Torch quantized types. - if (dtype.isa()) - return true; - // Builtin floating point types. - if (dtype.isa()) - return true; - // Builtin integer types. - if (IntegerType type = dtype.dyn_cast()) { - if (type.isSignless() && type.getWidth() == 1) - return true; - if (type.isSigned()) { - for (unsigned width : {8, 16, 32, 64}) { - if (type.getWidth() == width) - return true; - } - } - if (type.isUnsigned()) { - return type.getWidth() == 8; - } - } - return false; -} - -bool BaseTensorType::hasSameSizesAndDtype(BaseTensorType other) const { - return getOptionalSizes() == other.getOptionalSizes() && - getOptionalDtype() == other.getOptionalDtype(); -} - -Type BaseTensorType::getWithSizesAndDtypeFrom(BaseTensorType other) const { - return getWithSizesAndDtype(other.getOptionalSizes(), - other.getOptionalDtype()); -} - -Type BaseTensorType::getWithSizesAndDtype( - Optional> optionalSizes, Type optionalDtype) const { - if (isa()) - return NonValueTensorType::get(getContext(), optionalSizes, optionalDtype); - if (isa()) - return ValueTensorType::get(getContext(), optionalSizes, optionalDtype); - llvm_unreachable("not a BaseTensorType!"); -} - -ValueTensorType BaseTensorType::getWithValueSemantics() const { - if (auto tensor = dyn_cast()) - return tensor.getWithValueSemantics(); - if (auto tensor = dyn_cast()) - return tensor; - llvm_unreachable("not a BaseTensorType!"); -} - -static LogicalResult -verifyTensorType(function_ref emitError, - Optional> optionalSizes, - Type optionalDtype) { - if (optionalDtype && !isValidTorchDtype(optionalDtype)) { - emitError() << "invalid dtype " << optionalDtype - << " for !torch.tensor type"; - return failure(); - } - return success(); -} - -Type parseTensorType(MLIRContext *context, DialectAsmParser &parser, - GetTensorTypeFn getTensorType) { - llvm::SMLoc startLoc = parser.getCurrentLocation(); - if (parser.parseOptionalLess()) - return getTensorType(context, - /*optionalSizes=*/None, /*optionalDtype=*/Type()); - bool hasSizes; - SmallVector sizes; - if (succeeded(parser.parseOptionalStar())) { - // Unranked. - hasSizes = false; - } else { - // Parse list of sizes. - hasSizes = true; - if (parser.parseLSquare()) - return Type(); - for (bool first = true;; first = false) { - if (!first) { - if (failed(parser.parseOptionalComma())) { - break; - } - } - if (succeeded(parser.parseOptionalQuestion())) { - sizes.push_back(-1); - continue; - } - - int64_t size; - auto optionalInt = parser.parseOptionalInteger(size); - if (optionalInt.hasValue()) { - if (failed(*optionalInt)) - return Type(); - sizes.push_back(size); - continue; - } - break; - } - if (parser.parseRSquare()) { - return Type(); - } - } - if (parser.parseComma()) - return Type(); - Type optionalDtype; - if (succeeded(parser.parseOptionalKeyword("unk"))) { - // Unknown dtype. - } else { - // Known dtype. - if (parser.parseType(optionalDtype)) - return Type(); - } - if (parser.parseGreater()) - return Type(); - Optional> optionalSizes; - if (hasSizes) - optionalSizes.emplace(sizes); - - if (failed(verifyTensorType([&]() { return parser.emitError(startLoc); }, - optionalSizes, optionalDtype))) - return Type(); - - return getTensorType(context, optionalSizes, optionalDtype); -} - -static void printTensorType(DialectAsmPrinter &printer, - Optional> optionalSizes, - Type optionalDtype) { - if (!optionalSizes && !optionalDtype) - return; - printer << "<"; - if (optionalSizes) { - printer << "["; - for (auto it : llvm::enumerate(*optionalSizes)) { - if (it.index() > 0) - printer << ","; - if (it.value() < 0) - printer << "?"; - else - printer << it.value(); - } - printer << "]"; - } else { - printer << "*"; - } - printer << ","; - if (optionalDtype) - printer.printType(optionalDtype); - else - printer << "unk"; - printer << ">"; -} - -//===----------------------------------------------------------------------===// -// NonValueTensorType -//===----------------------------------------------------------------------===// - -ValueTensorType NonValueTensorType::getWithValueSemantics() const { - return ValueTensorType::get(getContext(), getOptionalSizes(), - getOptionalDtype()); -} - -NonValueTensorType -NonValueTensorType::getWithLeastStaticInformation(MLIRContext *context) { - return NonValueTensorType::get(context, - /*optionalSizes=*/None, - /*optionalDtype=*/Type()); -} - -NonValueTensorType NonValueTensorType::getFromShaped(ShapedType type) { - return NonValueTensorType::get(type.getContext(), - type.hasRank() ? type.getShape() - : Optional>(), - type.getElementType()); -} - -LogicalResult -NonValueTensorType::verify(function_ref emitError, - Optional> optionalSizes, - Type optionalDtype) { - return verifyTensorType(emitError, optionalSizes, optionalDtype); -} - -Type NonValueTensorType::parse(MLIRContext *context, DialectAsmParser &parser) { - return parseTensorType( - context, parser, - [](MLIRContext *context, Optional> optionalSizes, - Type optionalType) { - return NonValueTensorType::get(context, optionalSizes, optionalType); - }); -} - -void NonValueTensorType::print(DialectAsmPrinter &printer) const { - printer << "tensor"; - printTensorType(printer, getOptionalSizes(), getOptionalDtype()); -} - -//===----------------------------------------------------------------------===// -// ValueTensorType -//===----------------------------------------------------------------------===// - -NonValueTensorType ValueTensorType::getWithoutValueSemantics() const { - return NonValueTensorType::get(getContext(), getOptionalSizes(), - getOptionalDtype()); -} - -ValueTensorType -ValueTensorType::getWithLeastStaticInformation(MLIRContext *context) { - return ValueTensorType::get(context, - /*optionalSizes=*/None, - /*optionalDtype=*/Type()); -} - -ValueTensorType ValueTensorType::getFromShaped(ShapedType type) { - return ValueTensorType::get(type.getContext(), - type.hasRank() ? type.getShape() - : Optional>(), - type.getElementType()); -} - -TensorType ValueTensorType::toBuiltinTensor() const { - if (!hasDtype()) - return nullptr; - if (!hasSizes()) - return UnrankedTensorType::get(getDtype()); - return RankedTensorType::get(getSizes(), getDtype()); -} - -LogicalResult -ValueTensorType::verify(function_ref emitError, - Optional> optionalSizes, - Type optionalDtype) { - return verifyTensorType(emitError, optionalSizes, optionalDtype); -} - -Type ValueTensorType::parse(MLIRContext *context, DialectAsmParser &parser) { - return parseTensorType( - context, parser, - [](MLIRContext *context, Optional> optionalSizes, - Type optionalType) { - return ValueTensorType::get(context, optionalSizes, optionalType); - }); -} - -void ValueTensorType::print(DialectAsmPrinter &printer) const { - printer << "vtensor"; - printTensorType(printer, getOptionalSizes(), getOptionalDtype()); -} //===----------------------------------------------------------------------===// // Dialect-level verifiers. diff --git a/lib/Dialect/Torch/IR/TorchTypes.cpp b/lib/Dialect/Torch/IR/TorchTypes.cpp new file mode 100644 index 000000000..04dcda947 --- /dev/null +++ b/lib/Dialect/Torch/IR/TorchTypes.cpp @@ -0,0 +1,271 @@ +//===----------------------------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "npcomp/Dialect/Torch/IR/TorchTypes.h" +#include "mlir/IR/DialectImplementation.h" +#include "npcomp/Dialect/Torch/IR/TorchDialect.h" +#include "npcomp/Dialect/Torch/IR/TorchOps.h" + +using namespace mlir; +using namespace mlir::NPCOMP; +using namespace mlir::NPCOMP::Torch; + +//===----------------------------------------------------------------------===// +// BaseTensorType +//===----------------------------------------------------------------------===// + +static bool isValidTorchDtype(Type dtype) { + // Torch quantized types. + if (dtype.isa()) + return true; + // Builtin floating point types. + if (dtype.isa()) + return true; + // Builtin integer types. + if (IntegerType type = dtype.dyn_cast()) { + if (type.isSignless() && type.getWidth() == 1) + return true; + if (type.isSigned()) { + for (unsigned width : {8, 16, 32, 64}) { + if (type.getWidth() == width) + return true; + } + } + if (type.isUnsigned()) { + return type.getWidth() == 8; + } + } + return false; +} + +bool BaseTensorType::hasSameSizesAndDtype(BaseTensorType other) const { + return getOptionalSizes() == other.getOptionalSizes() && + getOptionalDtype() == other.getOptionalDtype(); +} + +Type BaseTensorType::getWithSizesAndDtypeFrom(BaseTensorType other) const { + return getWithSizesAndDtype(other.getOptionalSizes(), + other.getOptionalDtype()); +} + +Type BaseTensorType::getWithSizesAndDtype( + Optional> optionalSizes, Type optionalDtype) const { + if (isa()) + return NonValueTensorType::get(getContext(), optionalSizes, optionalDtype); + if (isa()) + return ValueTensorType::get(getContext(), optionalSizes, optionalDtype); + llvm_unreachable("not a BaseTensorType!"); +} + +ValueTensorType BaseTensorType::getWithValueSemantics() const { + if (auto tensor = dyn_cast()) + return tensor.getWithValueSemantics(); + if (auto tensor = dyn_cast()) + return tensor; + llvm_unreachable("not a BaseTensorType!"); +} + +static LogicalResult +verifyTensorType(function_ref emitError, + Optional> optionalSizes, + Type optionalDtype) { + if (optionalDtype && !isValidTorchDtype(optionalDtype)) { + emitError() << "invalid dtype " << optionalDtype + << " for !torch.tensor type"; + return failure(); + } + return success(); +} + +Type parseTensorType(MLIRContext *context, DialectAsmParser &parser, + GetTensorTypeFn getTensorType) { + llvm::SMLoc startLoc = parser.getCurrentLocation(); + if (parser.parseOptionalLess()) + return getTensorType(context, + /*optionalSizes=*/None, /*optionalDtype=*/Type()); + bool hasSizes; + SmallVector sizes; + if (succeeded(parser.parseOptionalStar())) { + // Unranked. + hasSizes = false; + } else { + // Parse list of sizes. + hasSizes = true; + if (parser.parseLSquare()) + return Type(); + for (bool first = true;; first = false) { + if (!first) { + if (failed(parser.parseOptionalComma())) { + break; + } + } + if (succeeded(parser.parseOptionalQuestion())) { + sizes.push_back(-1); + continue; + } + + int64_t size; + auto optionalInt = parser.parseOptionalInteger(size); + if (optionalInt.hasValue()) { + if (failed(*optionalInt)) + return Type(); + sizes.push_back(size); + continue; + } + break; + } + if (parser.parseRSquare()) { + return Type(); + } + } + if (parser.parseComma()) + return Type(); + Type optionalDtype; + if (succeeded(parser.parseOptionalKeyword("unk"))) { + // Unknown dtype. + } else { + // Known dtype. + if (parser.parseType(optionalDtype)) + return Type(); + } + if (parser.parseGreater()) + return Type(); + Optional> optionalSizes; + if (hasSizes) + optionalSizes.emplace(sizes); + + if (failed(verifyTensorType([&]() { return parser.emitError(startLoc); }, + optionalSizes, optionalDtype))) + return Type(); + + return getTensorType(context, optionalSizes, optionalDtype); +} + +static void printTensorType(DialectAsmPrinter &printer, + Optional> optionalSizes, + Type optionalDtype) { + if (!optionalSizes && !optionalDtype) + return; + printer << "<"; + if (optionalSizes) { + printer << "["; + for (auto it : llvm::enumerate(*optionalSizes)) { + if (it.index() > 0) + printer << ","; + if (it.value() < 0) + printer << "?"; + else + printer << it.value(); + } + printer << "]"; + } else { + printer << "*"; + } + printer << ","; + if (optionalDtype) + printer.printType(optionalDtype); + else + printer << "unk"; + printer << ">"; +} + +//===----------------------------------------------------------------------===// +// NonValueTensorType +//===----------------------------------------------------------------------===// + +ValueTensorType NonValueTensorType::getWithValueSemantics() const { + return ValueTensorType::get(getContext(), getOptionalSizes(), + getOptionalDtype()); +} + +NonValueTensorType +NonValueTensorType::getWithLeastStaticInformation(MLIRContext *context) { + return NonValueTensorType::get(context, + /*optionalSizes=*/None, + /*optionalDtype=*/Type()); +} + +NonValueTensorType NonValueTensorType::getFromShaped(ShapedType type) { + return NonValueTensorType::get(type.getContext(), + type.hasRank() ? type.getShape() + : Optional>(), + type.getElementType()); +} + +LogicalResult +NonValueTensorType::verify(function_ref emitError, + Optional> optionalSizes, + Type optionalDtype) { + return verifyTensorType(emitError, optionalSizes, optionalDtype); +} + +Type NonValueTensorType::parse(MLIRContext *context, DialectAsmParser &parser) { + return parseTensorType( + context, parser, + [](MLIRContext *context, Optional> optionalSizes, + Type optionalType) { + return NonValueTensorType::get(context, optionalSizes, optionalType); + }); +} + +void NonValueTensorType::print(DialectAsmPrinter &printer) const { + printer << "tensor"; + printTensorType(printer, getOptionalSizes(), getOptionalDtype()); +} + +//===----------------------------------------------------------------------===// +// ValueTensorType +//===----------------------------------------------------------------------===// + +NonValueTensorType ValueTensorType::getWithoutValueSemantics() const { + return NonValueTensorType::get(getContext(), getOptionalSizes(), + getOptionalDtype()); +} + +ValueTensorType +ValueTensorType::getWithLeastStaticInformation(MLIRContext *context) { + return ValueTensorType::get(context, + /*optionalSizes=*/None, + /*optionalDtype=*/Type()); +} + +ValueTensorType ValueTensorType::getFromShaped(ShapedType type) { + return ValueTensorType::get(type.getContext(), + type.hasRank() ? type.getShape() + : Optional>(), + type.getElementType()); +} + +TensorType ValueTensorType::toBuiltinTensor() const { + if (!hasDtype()) + return nullptr; + if (!hasSizes()) + return UnrankedTensorType::get(getDtype()); + return RankedTensorType::get(getSizes(), getDtype()); +} + +LogicalResult +ValueTensorType::verify(function_ref emitError, + Optional> optionalSizes, + Type optionalDtype) { + return verifyTensorType(emitError, optionalSizes, optionalDtype); +} + +Type ValueTensorType::parse(MLIRContext *context, DialectAsmParser &parser) { + return parseTensorType( + context, parser, + [](MLIRContext *context, Optional> optionalSizes, + Type optionalType) { + return ValueTensorType::get(context, optionalSizes, optionalType); + }); +} + +void ValueTensorType::print(DialectAsmPrinter &printer) const { + printer << "vtensor"; + printTensorType(printer, getOptionalSizes(), getOptionalDtype()); +}