mirror of https://github.com/llvm/torch-mlir
Move Torch type implementation code into TorchTypes.cpp
parent
0b6516c7cc
commit
81bcd7fb12
|
@ -9,7 +9,7 @@
|
||||||
#ifndef NPCOMP_DIALECT_TORCH_IR_TORCHTYPES_H
|
#ifndef NPCOMP_DIALECT_TORCH_IR_TORCHTYPES_H
|
||||||
#define NPCOMP_DIALECT_TORCH_IR_TORCHTYPES_H
|
#define NPCOMP_DIALECT_TORCH_IR_TORCHTYPES_H
|
||||||
|
|
||||||
#include "mlir/IR/Types.h"
|
#include "mlir/IR/BuiltinTypes.h"
|
||||||
|
|
||||||
namespace mlir {
|
namespace mlir {
|
||||||
namespace NPCOMP {
|
namespace NPCOMP {
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
add_npcomp_dialect_library(NPCOMPTorchDialect
|
add_npcomp_dialect_library(NPCOMPTorchDialect
|
||||||
TorchDialect.cpp
|
TorchDialect.cpp
|
||||||
TorchOps.cpp
|
TorchOps.cpp
|
||||||
|
TorchTypes.cpp
|
||||||
TorchUtils.cpp
|
TorchUtils.cpp
|
||||||
|
|
||||||
ADDITIONAL_HEADER_DIRS
|
ADDITIONAL_HEADER_DIRS
|
||||||
|
|
|
@ -47,6 +47,10 @@ struct TorchInlinerInterface : public DialectInlinerInterface {
|
||||||
#define GET_TYPEDEF_CLASSES
|
#define GET_TYPEDEF_CLASSES
|
||||||
#include "npcomp/Dialect/Torch/IR/TorchTypes.cpp.inc"
|
#include "npcomp/Dialect/Torch/IR/TorchTypes.cpp.inc"
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// Dialect initialize method.
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
void TorchDialect::initialize() {
|
void TorchDialect::initialize() {
|
||||||
addOperations<
|
addOperations<
|
||||||
#define GET_OP_LIST
|
#define GET_OP_LIST
|
||||||
|
@ -61,6 +65,10 @@ void TorchDialect::initialize() {
|
||||||
getContext()->loadDialect<Basicpy::BasicpyDialect>();
|
getContext()->loadDialect<Basicpy::BasicpyDialect>();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// Type-related Dialect methods.
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
Type TorchDialect::parseType(DialectAsmParser &parser) const {
|
Type TorchDialect::parseType(DialectAsmParser &parser) const {
|
||||||
StringRef keyword;
|
StringRef keyword;
|
||||||
if (parser.parseKeyword(&keyword))
|
if (parser.parseKeyword(&keyword))
|
||||||
|
@ -79,261 +87,6 @@ void TorchDialect::printType(Type type, DialectAsmPrinter &printer) const {
|
||||||
llvm_unreachable("unknown 'torch' type");
|
llvm_unreachable("unknown 'torch' type");
|
||||||
}
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
|
||||||
// BaseTensorType
|
|
||||||
//===----------------------------------------------------------------------===//
|
|
||||||
// TODO: Move most of this to a new file TorchTypes.cpp.
|
|
||||||
|
|
||||||
static bool isValidTorchDtype(Type dtype) {
|
|
||||||
// Torch quantized types.
|
|
||||||
if (dtype.isa<Torch::QInt8Type>())
|
|
||||||
return true;
|
|
||||||
// Builtin floating point types.
|
|
||||||
if (dtype.isa<Float16Type, BFloat16Type, Float32Type, Float64Type>())
|
|
||||||
return true;
|
|
||||||
// Builtin integer types.
|
|
||||||
if (IntegerType type = dtype.dyn_cast<IntegerType>()) {
|
|
||||||
if (type.isSignless() && type.getWidth() == 1)
|
|
||||||
return true;
|
|
||||||
if (type.isSigned()) {
|
|
||||||
for (unsigned width : {8, 16, 32, 64}) {
|
|
||||||
if (type.getWidth() == width)
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if (type.isUnsigned()) {
|
|
||||||
return type.getWidth() == 8;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
bool BaseTensorType::hasSameSizesAndDtype(BaseTensorType other) const {
|
|
||||||
return getOptionalSizes() == other.getOptionalSizes() &&
|
|
||||||
getOptionalDtype() == other.getOptionalDtype();
|
|
||||||
}
|
|
||||||
|
|
||||||
Type BaseTensorType::getWithSizesAndDtypeFrom(BaseTensorType other) const {
|
|
||||||
return getWithSizesAndDtype(other.getOptionalSizes(),
|
|
||||||
other.getOptionalDtype());
|
|
||||||
}
|
|
||||||
|
|
||||||
Type BaseTensorType::getWithSizesAndDtype(
|
|
||||||
Optional<ArrayRef<int64_t>> optionalSizes, Type optionalDtype) const {
|
|
||||||
if (isa<NonValueTensorType>())
|
|
||||||
return NonValueTensorType::get(getContext(), optionalSizes, optionalDtype);
|
|
||||||
if (isa<ValueTensorType>())
|
|
||||||
return ValueTensorType::get(getContext(), optionalSizes, optionalDtype);
|
|
||||||
llvm_unreachable("not a BaseTensorType!");
|
|
||||||
}
|
|
||||||
|
|
||||||
ValueTensorType BaseTensorType::getWithValueSemantics() const {
|
|
||||||
if (auto tensor = dyn_cast<NonValueTensorType>())
|
|
||||||
return tensor.getWithValueSemantics();
|
|
||||||
if (auto tensor = dyn_cast<ValueTensorType>())
|
|
||||||
return tensor;
|
|
||||||
llvm_unreachable("not a BaseTensorType!");
|
|
||||||
}
|
|
||||||
|
|
||||||
static LogicalResult
|
|
||||||
verifyTensorType(function_ref<InFlightDiagnostic()> emitError,
|
|
||||||
Optional<ArrayRef<int64_t>> optionalSizes,
|
|
||||||
Type optionalDtype) {
|
|
||||||
if (optionalDtype && !isValidTorchDtype(optionalDtype)) {
|
|
||||||
emitError() << "invalid dtype " << optionalDtype
|
|
||||||
<< " for !torch.tensor type";
|
|
||||||
return failure();
|
|
||||||
}
|
|
||||||
return success();
|
|
||||||
}
|
|
||||||
|
|
||||||
Type parseTensorType(MLIRContext *context, DialectAsmParser &parser,
|
|
||||||
GetTensorTypeFn getTensorType) {
|
|
||||||
llvm::SMLoc startLoc = parser.getCurrentLocation();
|
|
||||||
if (parser.parseOptionalLess())
|
|
||||||
return getTensorType(context,
|
|
||||||
/*optionalSizes=*/None, /*optionalDtype=*/Type());
|
|
||||||
bool hasSizes;
|
|
||||||
SmallVector<int64_t> sizes;
|
|
||||||
if (succeeded(parser.parseOptionalStar())) {
|
|
||||||
// Unranked.
|
|
||||||
hasSizes = false;
|
|
||||||
} else {
|
|
||||||
// Parse list of sizes.
|
|
||||||
hasSizes = true;
|
|
||||||
if (parser.parseLSquare())
|
|
||||||
return Type();
|
|
||||||
for (bool first = true;; first = false) {
|
|
||||||
if (!first) {
|
|
||||||
if (failed(parser.parseOptionalComma())) {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if (succeeded(parser.parseOptionalQuestion())) {
|
|
||||||
sizes.push_back(-1);
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
int64_t size;
|
|
||||||
auto optionalInt = parser.parseOptionalInteger(size);
|
|
||||||
if (optionalInt.hasValue()) {
|
|
||||||
if (failed(*optionalInt))
|
|
||||||
return Type();
|
|
||||||
sizes.push_back(size);
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
if (parser.parseRSquare()) {
|
|
||||||
return Type();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if (parser.parseComma())
|
|
||||||
return Type();
|
|
||||||
Type optionalDtype;
|
|
||||||
if (succeeded(parser.parseOptionalKeyword("unk"))) {
|
|
||||||
// Unknown dtype.
|
|
||||||
} else {
|
|
||||||
// Known dtype.
|
|
||||||
if (parser.parseType(optionalDtype))
|
|
||||||
return Type();
|
|
||||||
}
|
|
||||||
if (parser.parseGreater())
|
|
||||||
return Type();
|
|
||||||
Optional<ArrayRef<int64_t>> optionalSizes;
|
|
||||||
if (hasSizes)
|
|
||||||
optionalSizes.emplace(sizes);
|
|
||||||
|
|
||||||
if (failed(verifyTensorType([&]() { return parser.emitError(startLoc); },
|
|
||||||
optionalSizes, optionalDtype)))
|
|
||||||
return Type();
|
|
||||||
|
|
||||||
return getTensorType(context, optionalSizes, optionalDtype);
|
|
||||||
}
|
|
||||||
|
|
||||||
static void printTensorType(DialectAsmPrinter &printer,
|
|
||||||
Optional<ArrayRef<int64_t>> optionalSizes,
|
|
||||||
Type optionalDtype) {
|
|
||||||
if (!optionalSizes && !optionalDtype)
|
|
||||||
return;
|
|
||||||
printer << "<";
|
|
||||||
if (optionalSizes) {
|
|
||||||
printer << "[";
|
|
||||||
for (auto it : llvm::enumerate(*optionalSizes)) {
|
|
||||||
if (it.index() > 0)
|
|
||||||
printer << ",";
|
|
||||||
if (it.value() < 0)
|
|
||||||
printer << "?";
|
|
||||||
else
|
|
||||||
printer << it.value();
|
|
||||||
}
|
|
||||||
printer << "]";
|
|
||||||
} else {
|
|
||||||
printer << "*";
|
|
||||||
}
|
|
||||||
printer << ",";
|
|
||||||
if (optionalDtype)
|
|
||||||
printer.printType(optionalDtype);
|
|
||||||
else
|
|
||||||
printer << "unk";
|
|
||||||
printer << ">";
|
|
||||||
}
|
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
|
||||||
// NonValueTensorType
|
|
||||||
//===----------------------------------------------------------------------===//
|
|
||||||
|
|
||||||
ValueTensorType NonValueTensorType::getWithValueSemantics() const {
|
|
||||||
return ValueTensorType::get(getContext(), getOptionalSizes(),
|
|
||||||
getOptionalDtype());
|
|
||||||
}
|
|
||||||
|
|
||||||
NonValueTensorType
|
|
||||||
NonValueTensorType::getWithLeastStaticInformation(MLIRContext *context) {
|
|
||||||
return NonValueTensorType::get(context,
|
|
||||||
/*optionalSizes=*/None,
|
|
||||||
/*optionalDtype=*/Type());
|
|
||||||
}
|
|
||||||
|
|
||||||
NonValueTensorType NonValueTensorType::getFromShaped(ShapedType type) {
|
|
||||||
return NonValueTensorType::get(type.getContext(),
|
|
||||||
type.hasRank() ? type.getShape()
|
|
||||||
: Optional<ArrayRef<int64_t>>(),
|
|
||||||
type.getElementType());
|
|
||||||
}
|
|
||||||
|
|
||||||
LogicalResult
|
|
||||||
NonValueTensorType::verify(function_ref<InFlightDiagnostic()> emitError,
|
|
||||||
Optional<ArrayRef<int64_t>> optionalSizes,
|
|
||||||
Type optionalDtype) {
|
|
||||||
return verifyTensorType(emitError, optionalSizes, optionalDtype);
|
|
||||||
}
|
|
||||||
|
|
||||||
Type NonValueTensorType::parse(MLIRContext *context, DialectAsmParser &parser) {
|
|
||||||
return parseTensorType(
|
|
||||||
context, parser,
|
|
||||||
[](MLIRContext *context, Optional<ArrayRef<int64_t>> optionalSizes,
|
|
||||||
Type optionalType) {
|
|
||||||
return NonValueTensorType::get(context, optionalSizes, optionalType);
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
void NonValueTensorType::print(DialectAsmPrinter &printer) const {
|
|
||||||
printer << "tensor";
|
|
||||||
printTensorType(printer, getOptionalSizes(), getOptionalDtype());
|
|
||||||
}
|
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
|
||||||
// ValueTensorType
|
|
||||||
//===----------------------------------------------------------------------===//
|
|
||||||
|
|
||||||
NonValueTensorType ValueTensorType::getWithoutValueSemantics() const {
|
|
||||||
return NonValueTensorType::get(getContext(), getOptionalSizes(),
|
|
||||||
getOptionalDtype());
|
|
||||||
}
|
|
||||||
|
|
||||||
ValueTensorType
|
|
||||||
ValueTensorType::getWithLeastStaticInformation(MLIRContext *context) {
|
|
||||||
return ValueTensorType::get(context,
|
|
||||||
/*optionalSizes=*/None,
|
|
||||||
/*optionalDtype=*/Type());
|
|
||||||
}
|
|
||||||
|
|
||||||
ValueTensorType ValueTensorType::getFromShaped(ShapedType type) {
|
|
||||||
return ValueTensorType::get(type.getContext(),
|
|
||||||
type.hasRank() ? type.getShape()
|
|
||||||
: Optional<ArrayRef<int64_t>>(),
|
|
||||||
type.getElementType());
|
|
||||||
}
|
|
||||||
|
|
||||||
TensorType ValueTensorType::toBuiltinTensor() const {
|
|
||||||
if (!hasDtype())
|
|
||||||
return nullptr;
|
|
||||||
if (!hasSizes())
|
|
||||||
return UnrankedTensorType::get(getDtype());
|
|
||||||
return RankedTensorType::get(getSizes(), getDtype());
|
|
||||||
}
|
|
||||||
|
|
||||||
LogicalResult
|
|
||||||
ValueTensorType::verify(function_ref<InFlightDiagnostic()> emitError,
|
|
||||||
Optional<ArrayRef<int64_t>> optionalSizes,
|
|
||||||
Type optionalDtype) {
|
|
||||||
return verifyTensorType(emitError, optionalSizes, optionalDtype);
|
|
||||||
}
|
|
||||||
|
|
||||||
Type ValueTensorType::parse(MLIRContext *context, DialectAsmParser &parser) {
|
|
||||||
return parseTensorType(
|
|
||||||
context, parser,
|
|
||||||
[](MLIRContext *context, Optional<ArrayRef<int64_t>> optionalSizes,
|
|
||||||
Type optionalType) {
|
|
||||||
return ValueTensorType::get(context, optionalSizes, optionalType);
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
void ValueTensorType::print(DialectAsmPrinter &printer) const {
|
|
||||||
printer << "vtensor";
|
|
||||||
printTensorType(printer, getOptionalSizes(), getOptionalDtype());
|
|
||||||
}
|
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// Dialect-level verifiers.
|
// Dialect-level verifiers.
|
||||||
|
|
|
@ -0,0 +1,271 @@
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
//
|
||||||
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
|
// See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#include "npcomp/Dialect/Torch/IR/TorchTypes.h"
|
||||||
|
#include "mlir/IR/DialectImplementation.h"
|
||||||
|
#include "npcomp/Dialect/Torch/IR/TorchDialect.h"
|
||||||
|
#include "npcomp/Dialect/Torch/IR/TorchOps.h"
|
||||||
|
|
||||||
|
using namespace mlir;
|
||||||
|
using namespace mlir::NPCOMP;
|
||||||
|
using namespace mlir::NPCOMP::Torch;
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// BaseTensorType
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
static bool isValidTorchDtype(Type dtype) {
|
||||||
|
// Torch quantized types.
|
||||||
|
if (dtype.isa<Torch::QInt8Type>())
|
||||||
|
return true;
|
||||||
|
// Builtin floating point types.
|
||||||
|
if (dtype.isa<Float16Type, BFloat16Type, Float32Type, Float64Type>())
|
||||||
|
return true;
|
||||||
|
// Builtin integer types.
|
||||||
|
if (IntegerType type = dtype.dyn_cast<IntegerType>()) {
|
||||||
|
if (type.isSignless() && type.getWidth() == 1)
|
||||||
|
return true;
|
||||||
|
if (type.isSigned()) {
|
||||||
|
for (unsigned width : {8, 16, 32, 64}) {
|
||||||
|
if (type.getWidth() == width)
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (type.isUnsigned()) {
|
||||||
|
return type.getWidth() == 8;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool BaseTensorType::hasSameSizesAndDtype(BaseTensorType other) const {
|
||||||
|
return getOptionalSizes() == other.getOptionalSizes() &&
|
||||||
|
getOptionalDtype() == other.getOptionalDtype();
|
||||||
|
}
|
||||||
|
|
||||||
|
Type BaseTensorType::getWithSizesAndDtypeFrom(BaseTensorType other) const {
|
||||||
|
return getWithSizesAndDtype(other.getOptionalSizes(),
|
||||||
|
other.getOptionalDtype());
|
||||||
|
}
|
||||||
|
|
||||||
|
Type BaseTensorType::getWithSizesAndDtype(
|
||||||
|
Optional<ArrayRef<int64_t>> optionalSizes, Type optionalDtype) const {
|
||||||
|
if (isa<NonValueTensorType>())
|
||||||
|
return NonValueTensorType::get(getContext(), optionalSizes, optionalDtype);
|
||||||
|
if (isa<ValueTensorType>())
|
||||||
|
return ValueTensorType::get(getContext(), optionalSizes, optionalDtype);
|
||||||
|
llvm_unreachable("not a BaseTensorType!");
|
||||||
|
}
|
||||||
|
|
||||||
|
ValueTensorType BaseTensorType::getWithValueSemantics() const {
|
||||||
|
if (auto tensor = dyn_cast<NonValueTensorType>())
|
||||||
|
return tensor.getWithValueSemantics();
|
||||||
|
if (auto tensor = dyn_cast<ValueTensorType>())
|
||||||
|
return tensor;
|
||||||
|
llvm_unreachable("not a BaseTensorType!");
|
||||||
|
}
|
||||||
|
|
||||||
|
static LogicalResult
|
||||||
|
verifyTensorType(function_ref<InFlightDiagnostic()> emitError,
|
||||||
|
Optional<ArrayRef<int64_t>> optionalSizes,
|
||||||
|
Type optionalDtype) {
|
||||||
|
if (optionalDtype && !isValidTorchDtype(optionalDtype)) {
|
||||||
|
emitError() << "invalid dtype " << optionalDtype
|
||||||
|
<< " for !torch.tensor type";
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
Type parseTensorType(MLIRContext *context, DialectAsmParser &parser,
|
||||||
|
GetTensorTypeFn getTensorType) {
|
||||||
|
llvm::SMLoc startLoc = parser.getCurrentLocation();
|
||||||
|
if (parser.parseOptionalLess())
|
||||||
|
return getTensorType(context,
|
||||||
|
/*optionalSizes=*/None, /*optionalDtype=*/Type());
|
||||||
|
bool hasSizes;
|
||||||
|
SmallVector<int64_t> sizes;
|
||||||
|
if (succeeded(parser.parseOptionalStar())) {
|
||||||
|
// Unranked.
|
||||||
|
hasSizes = false;
|
||||||
|
} else {
|
||||||
|
// Parse list of sizes.
|
||||||
|
hasSizes = true;
|
||||||
|
if (parser.parseLSquare())
|
||||||
|
return Type();
|
||||||
|
for (bool first = true;; first = false) {
|
||||||
|
if (!first) {
|
||||||
|
if (failed(parser.parseOptionalComma())) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (succeeded(parser.parseOptionalQuestion())) {
|
||||||
|
sizes.push_back(-1);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
int64_t size;
|
||||||
|
auto optionalInt = parser.parseOptionalInteger(size);
|
||||||
|
if (optionalInt.hasValue()) {
|
||||||
|
if (failed(*optionalInt))
|
||||||
|
return Type();
|
||||||
|
sizes.push_back(size);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
if (parser.parseRSquare()) {
|
||||||
|
return Type();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (parser.parseComma())
|
||||||
|
return Type();
|
||||||
|
Type optionalDtype;
|
||||||
|
if (succeeded(parser.parseOptionalKeyword("unk"))) {
|
||||||
|
// Unknown dtype.
|
||||||
|
} else {
|
||||||
|
// Known dtype.
|
||||||
|
if (parser.parseType(optionalDtype))
|
||||||
|
return Type();
|
||||||
|
}
|
||||||
|
if (parser.parseGreater())
|
||||||
|
return Type();
|
||||||
|
Optional<ArrayRef<int64_t>> optionalSizes;
|
||||||
|
if (hasSizes)
|
||||||
|
optionalSizes.emplace(sizes);
|
||||||
|
|
||||||
|
if (failed(verifyTensorType([&]() { return parser.emitError(startLoc); },
|
||||||
|
optionalSizes, optionalDtype)))
|
||||||
|
return Type();
|
||||||
|
|
||||||
|
return getTensorType(context, optionalSizes, optionalDtype);
|
||||||
|
}
|
||||||
|
|
||||||
|
static void printTensorType(DialectAsmPrinter &printer,
|
||||||
|
Optional<ArrayRef<int64_t>> optionalSizes,
|
||||||
|
Type optionalDtype) {
|
||||||
|
if (!optionalSizes && !optionalDtype)
|
||||||
|
return;
|
||||||
|
printer << "<";
|
||||||
|
if (optionalSizes) {
|
||||||
|
printer << "[";
|
||||||
|
for (auto it : llvm::enumerate(*optionalSizes)) {
|
||||||
|
if (it.index() > 0)
|
||||||
|
printer << ",";
|
||||||
|
if (it.value() < 0)
|
||||||
|
printer << "?";
|
||||||
|
else
|
||||||
|
printer << it.value();
|
||||||
|
}
|
||||||
|
printer << "]";
|
||||||
|
} else {
|
||||||
|
printer << "*";
|
||||||
|
}
|
||||||
|
printer << ",";
|
||||||
|
if (optionalDtype)
|
||||||
|
printer.printType(optionalDtype);
|
||||||
|
else
|
||||||
|
printer << "unk";
|
||||||
|
printer << ">";
|
||||||
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// NonValueTensorType
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
ValueTensorType NonValueTensorType::getWithValueSemantics() const {
|
||||||
|
return ValueTensorType::get(getContext(), getOptionalSizes(),
|
||||||
|
getOptionalDtype());
|
||||||
|
}
|
||||||
|
|
||||||
|
NonValueTensorType
|
||||||
|
NonValueTensorType::getWithLeastStaticInformation(MLIRContext *context) {
|
||||||
|
return NonValueTensorType::get(context,
|
||||||
|
/*optionalSizes=*/None,
|
||||||
|
/*optionalDtype=*/Type());
|
||||||
|
}
|
||||||
|
|
||||||
|
NonValueTensorType NonValueTensorType::getFromShaped(ShapedType type) {
|
||||||
|
return NonValueTensorType::get(type.getContext(),
|
||||||
|
type.hasRank() ? type.getShape()
|
||||||
|
: Optional<ArrayRef<int64_t>>(),
|
||||||
|
type.getElementType());
|
||||||
|
}
|
||||||
|
|
||||||
|
LogicalResult
|
||||||
|
NonValueTensorType::verify(function_ref<InFlightDiagnostic()> emitError,
|
||||||
|
Optional<ArrayRef<int64_t>> optionalSizes,
|
||||||
|
Type optionalDtype) {
|
||||||
|
return verifyTensorType(emitError, optionalSizes, optionalDtype);
|
||||||
|
}
|
||||||
|
|
||||||
|
Type NonValueTensorType::parse(MLIRContext *context, DialectAsmParser &parser) {
|
||||||
|
return parseTensorType(
|
||||||
|
context, parser,
|
||||||
|
[](MLIRContext *context, Optional<ArrayRef<int64_t>> optionalSizes,
|
||||||
|
Type optionalType) {
|
||||||
|
return NonValueTensorType::get(context, optionalSizes, optionalType);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
void NonValueTensorType::print(DialectAsmPrinter &printer) const {
|
||||||
|
printer << "tensor";
|
||||||
|
printTensorType(printer, getOptionalSizes(), getOptionalDtype());
|
||||||
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// ValueTensorType
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
NonValueTensorType ValueTensorType::getWithoutValueSemantics() const {
|
||||||
|
return NonValueTensorType::get(getContext(), getOptionalSizes(),
|
||||||
|
getOptionalDtype());
|
||||||
|
}
|
||||||
|
|
||||||
|
ValueTensorType
|
||||||
|
ValueTensorType::getWithLeastStaticInformation(MLIRContext *context) {
|
||||||
|
return ValueTensorType::get(context,
|
||||||
|
/*optionalSizes=*/None,
|
||||||
|
/*optionalDtype=*/Type());
|
||||||
|
}
|
||||||
|
|
||||||
|
ValueTensorType ValueTensorType::getFromShaped(ShapedType type) {
|
||||||
|
return ValueTensorType::get(type.getContext(),
|
||||||
|
type.hasRank() ? type.getShape()
|
||||||
|
: Optional<ArrayRef<int64_t>>(),
|
||||||
|
type.getElementType());
|
||||||
|
}
|
||||||
|
|
||||||
|
TensorType ValueTensorType::toBuiltinTensor() const {
|
||||||
|
if (!hasDtype())
|
||||||
|
return nullptr;
|
||||||
|
if (!hasSizes())
|
||||||
|
return UnrankedTensorType::get(getDtype());
|
||||||
|
return RankedTensorType::get(getSizes(), getDtype());
|
||||||
|
}
|
||||||
|
|
||||||
|
LogicalResult
|
||||||
|
ValueTensorType::verify(function_ref<InFlightDiagnostic()> emitError,
|
||||||
|
Optional<ArrayRef<int64_t>> optionalSizes,
|
||||||
|
Type optionalDtype) {
|
||||||
|
return verifyTensorType(emitError, optionalSizes, optionalDtype);
|
||||||
|
}
|
||||||
|
|
||||||
|
Type ValueTensorType::parse(MLIRContext *context, DialectAsmParser &parser) {
|
||||||
|
return parseTensorType(
|
||||||
|
context, parser,
|
||||||
|
[](MLIRContext *context, Optional<ArrayRef<int64_t>> optionalSizes,
|
||||||
|
Type optionalType) {
|
||||||
|
return ValueTensorType::get(context, optionalSizes, optionalType);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
void ValueTensorType::print(DialectAsmPrinter &printer) const {
|
||||||
|
printer << "vtensor";
|
||||||
|
printTensorType(printer, getOptionalSizes(), getOptionalDtype());
|
||||||
|
}
|
Loading…
Reference in New Issue