
//===----------------------------------------------------------------------===//

// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
#include "npcomp/Dialect/Torch/IR/TorchDialect.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/Transforms/InliningUtils.h"
#include "npcomp/Dialect/Basicpy/IR/BasicpyDialect.h"
#include "npcomp/Dialect/Basicpy/IR/BasicpyOps.h"
#include "npcomp/Dialect/Torch/IR/TorchOps.h"
#include "npcomp/Dialect/Torch/IR/TorchTypes.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/TypeSwitch.h"
using namespace mlir;
using namespace mlir::NPCOMP;
using namespace mlir::NPCOMP::Torch;
// Dialect Interfaces
namespace {
struct TorchInlinerInterface : public DialectInlinerInterface {
using DialectInlinerInterface::DialectInlinerInterface;
bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned,
BlockAndValueMapping &valueMapping) const final {
return true;
bool isLegalToInline(Operation *, Region *, bool wouldBeCloned,
BlockAndValueMapping &) const final {
return true;
} // end anonymous namespace
// Tablegen Type Definitions
#include "npcomp/Dialect/Torch/IR/"
void TorchDialect::initialize() {
#define GET_OP_LIST
#include "npcomp/Dialect/Torch/IR/"
#include "npcomp/Dialect/Torch/IR/"
/// Parses a type of the 'torch' dialect.
///
/// Reads the leading keyword and dispatches to the tablegen-generated parser.
/// Emits an "invalid 'torch' type" diagnostic and returns a null Type if the
/// keyword is not a known type mnemonic.
Type TorchDialect::parseType(DialectAsmParser &parser) const {
  StringRef keyword;
  if (parser.parseKeyword(&keyword))
    return Type();
  Type type;
  // hasValue() means the generated parser recognized the keyword (it may still
  // have produced a null type on a parse error, which is returned as-is).
  if (generatedTypeParser(getContext(), parser, keyword, type).hasValue())
    return type;

  parser.emitError(parser.getNameLoc(), "invalid 'torch' type: `")
      << keyword << "'";
  return Type();
}
/// Prints a type of the 'torch' dialect via the tablegen-generated printer.
/// Every registered Torch type is expected to be handled there, so failure is
/// treated as unreachable.
void TorchDialect::printType(Type type, DialectAsmPrinter &printer) const {
  if (failed(generatedTypePrinter(type, printer)))
    llvm_unreachable("unknown 'torch' type");
}
// BaseTensorType
// TODO: Move most of this to a new file TorchTypes.cpp.
/// Returns true if `dtype` is an element type allowed inside a Torch tensor
/// type.
///
/// Accepted types:
///   - !torch.qint8
///   - f16, bf16, f32, f64
///   - i1 (signless)
///   - si8, si16, si32, si64
///   - ui8
static bool isValidTorchDtype(Type dtype) {
  // Torch quantized types.
  if (dtype.isa<Torch::QInt8Type>())
    return true;
  // Builtin floating point types.
  if (dtype.isa<Float16Type, BFloat16Type, Float32Type, Float64Type>())
    return true;
  // Builtin integer types.
  if (IntegerType type = dtype.dyn_cast<IntegerType>()) {
    // Signless integers are only accepted as the boolean type.
    if (type.isSignless() && type.getWidth() == 1)
      return true;
    if (type.isSigned()) {
      for (unsigned width : {8, 16, 32, 64}) {
        if (type.getWidth() == width)
          return true;
      }
    }
    // The only unsigned integer type accepted is ui8.
    if (type.isUnsigned()) {
      return type.getWidth() == 8;
    }
  }
  return false;
}
/// Returns true if this type and `other` have identical optional sizes and
/// identical optional dtype. Value vs. non-value semantics is not compared.
bool BaseTensorType::hasSameSizesAndDtype(BaseTensorType other) const {
  return getOptionalSizes() == other.getOptionalSizes() &&
         getOptionalDtype() == other.getOptionalDtype();
}
/// Returns a tensor type of the same kind as this one (!torch.tensor or
/// !torch.vtensor), but with the sizes and dtype copied from `other`.
Type BaseTensorType::getWithSizesAndDtypeFrom(BaseTensorType other) const {
  return getWithSizesAndDtype(other.getOptionalSizes(),
                              other.getOptionalDtype());
}
/// Returns a tensor type of the same kind as this one (non-value or value
/// semantics) with the given optional sizes and optional dtype.
Type BaseTensorType::getWithSizesAndDtype(
    Optional<ArrayRef<int64_t>> optionalSizes, Type optionalDtype) const {
  if (isa<NonValueTensorType>())
    return NonValueTensorType::get(getContext(), optionalSizes, optionalDtype);
  if (isa<ValueTensorType>())
    return ValueTensorType::get(getContext(), optionalSizes, optionalDtype);
  llvm_unreachable("not a BaseTensorType!");
}
/// Returns the value-semantic (!torch.vtensor) counterpart of this type.
/// A ValueTensorType is returned unchanged.
ValueTensorType BaseTensorType::getWithValueSemantics() const {
  if (auto tensor = dyn_cast<NonValueTensorType>())
    return tensor.getWithValueSemantics();
  if (auto tensor = dyn_cast<ValueTensorType>())
    return tensor;
  llvm_unreachable("not a BaseTensorType!");
}
/// Shared verifier for !torch.tensor and !torch.vtensor.
///
/// A null `optionalDtype` (unknown dtype) is always accepted. Otherwise the
/// dtype must satisfy isValidTorchDtype, and an error is emitted through
/// `emitError` when it does not. Sizes are not checked here.
static LogicalResult
verifyTensorType(function_ref<InFlightDiagnostic()> emitError,
                 Optional<ArrayRef<int64_t>> optionalSizes,
                 Type optionalDtype) {
  if (optionalDtype && !isValidTorchDtype(optionalDtype)) {
    emitError() << "invalid dtype " << optionalDtype
                << " for !torch.tensor type";
    return failure();
  }
  return success();
}
/// Parses the optional `<sizes,dtype>` suffix shared by !torch.tensor and
/// !torch.vtensor, and builds the concrete type with `getTensorType`.
///
/// Accepted forms (mirroring printTensorType):
///   (nothing)          -> unknown sizes, unknown dtype
///   <*,DTYPE>          -> unranked
///   <[d0,d1,...],DTYPE> where each di is an integer or `?` (stored as -1)
/// DTYPE is either a type or the keyword `unk` (unknown dtype).
/// Returns a null Type on a parse or verification error.
Type parseTensorType(MLIRContext *context, DialectAsmParser &parser,
                     GetTensorTypeFn getTensorType) {
  llvm::SMLoc startLoc = parser.getCurrentLocation();
  // No `<...>` at all means "no static information".
  if (parser.parseOptionalLess())
    return getTensorType(context,
                         /*optionalSizes=*/None, /*optionalDtype=*/Type());
  bool hasSizes;
  SmallVector<int64_t> sizes;
  if (succeeded(parser.parseOptionalStar())) {
    // Unranked.
    hasSizes = false;
  } else {
    // Parse list of sizes.
    hasSizes = true;
    if (parser.parseLSquare())
      return Type();
    for (bool first = true;; first = false) {
      if (!first) {
        // Elements are comma-separated; no comma ends the list.
        if (failed(parser.parseOptionalComma()))
          break;
      }
      // `?` denotes a dimension of unknown size, encoded as -1.
      if (succeeded(parser.parseOptionalQuestion())) {
        sizes.push_back(-1);
        continue;
      }
      int64_t size;
      auto optionalInt = parser.parseOptionalInteger(size);
      if (optionalInt.hasValue()) {
        // An integer token was present but could not be parsed.
        if (failed(*optionalInt))
          return Type();
        sizes.push_back(size);
        continue;
      }
      // Neither `?` nor an integer (e.g. an empty list `[]`).
      break;
    }
    if (parser.parseRSquare()) {
      return Type();
    }
  }
  if (parser.parseComma())
    return Type();
  Type optionalDtype;
  if (succeeded(parser.parseOptionalKeyword("unk"))) {
    // Unknown dtype.
  } else {
    // Known dtype.
    if (parser.parseType(optionalDtype))
      return Type();
  }
  if (parser.parseGreater())
    return Type();
  Optional<ArrayRef<int64_t>> optionalSizes;
  if (hasSizes)
    optionalSizes.emplace(sizes);
  if (failed(verifyTensorType([&]() { return parser.emitError(startLoc); },
                              optionalSizes, optionalDtype)))
    return Type();

  return getTensorType(context, optionalSizes, optionalDtype);
}
/// Prints the optional `<sizes,dtype>` suffix of a Torch tensor type, in the
/// exact syntax accepted by parseTensorType:
///   - nothing at all when both sizes and dtype are unknown;
///   - `*` for unranked, otherwise `[d0,d1,...]` with `?` for negative
///     (unknown) dimensions;
///   - the dtype, or `unk` when the dtype is unknown.
static void printTensorType(DialectAsmPrinter &printer,
                            Optional<ArrayRef<int64_t>> optionalSizes,
                            Type optionalDtype) {
  // Least-static-information form: the bare keyword, no `<...>`.
  if (!optionalSizes && !optionalDtype)
    return;
  printer << "<";
  if (optionalSizes) {
    printer << "[";
    for (auto it : llvm::enumerate(*optionalSizes)) {
      if (it.index() > 0)
        printer << ",";
      // Negative sizes encode an unknown dimension.
      if (it.value() < 0)
        printer << "?";
      else
        printer << it.value();
    }
    printer << "]";
  } else {
    printer << "*";
  }
  printer << ",";
  if (optionalDtype)
    printer.printType(optionalDtype);
  else
    printer << "unk";
  printer << ">";
}
// NonValueTensorType
/// Returns the !torch.vtensor type with the same sizes and dtype as this type.
ValueTensorType NonValueTensorType::getWithValueSemantics() const {
  return ValueTensorType::get(getContext(), getOptionalSizes(),
                              getOptionalDtype());
}
/// Returns !torch.tensor with unknown sizes and unknown dtype. This is the
/// same type produced when parsing a bare `tensor` with no `<...>` suffix.
NonValueTensorType
NonValueTensorType::getWithLeastStaticInformation(MLIRContext *context) {
  return NonValueTensorType::get(context,
                                 /*optionalSizes=*/None,
                                 /*optionalDtype=*/Type());
}
/// Builds a !torch.tensor mirroring a builtin shaped type. The result has the
/// shape if the input is ranked (unranked otherwise) and uses the element
/// type as its dtype.
NonValueTensorType NonValueTensorType::getFromShaped(ShapedType type) {
  return NonValueTensorType::get(type.getContext(),
                                 type.hasRank() ? type.getShape()
                                                : Optional<ArrayRef<int64_t>>(),
                                 type.getElementType());
}
NonValueTensorType::verify(function_ref<InFlightDiagnostic()> emitError,
Optional<ArrayRef<int64_t>> optionalSizes,
Type optionalDtype) {
return verifyTensorType(emitError, optionalSizes, optionalDtype);
/// Parses the suffix of a !torch.tensor type via parseTensorType.
Type NonValueTensorType::parse(MLIRContext *context, DialectAsmParser &parser) {
  return parseTensorType(
      context, parser,
      [](MLIRContext *context, Optional<ArrayRef<int64_t>> optionalSizes,
         Type optionalType) {
        return NonValueTensorType::get(context, optionalSizes, optionalType);
      });
}
/// Prints this type as `tensor` followed by the shared `<sizes,dtype>` suffix.
void NonValueTensorType::print(DialectAsmPrinter &printer) const {
  printer << "tensor";
  printTensorType(printer, getOptionalSizes(), getOptionalDtype());
}
// ValueTensorType
/// Returns the !torch.tensor (non-value-semantic) type with the same sizes and
/// dtype as this type.
NonValueTensorType ValueTensorType::getWithoutValueSemantics() const {
  return NonValueTensorType::get(getContext(), getOptionalSizes(),
                                 getOptionalDtype());
}
/// Returns !torch.vtensor with unknown sizes and unknown dtype. This is the
/// same type produced when parsing a bare `vtensor` with no `<...>` suffix.
ValueTensorType
ValueTensorType::getWithLeastStaticInformation(MLIRContext *context) {
  return ValueTensorType::get(context,
                              /*optionalSizes=*/None,
                              /*optionalDtype=*/Type());
}
/// Builds a !torch.vtensor mirroring a builtin shaped type. The result has the
/// shape if the input is ranked (unranked otherwise) and uses the element
/// type as its dtype.
ValueTensorType ValueTensorType::getFromShaped(ShapedType type) {
  return ValueTensorType::get(type.getContext(),
                              type.hasRank() ? type.getShape()
                                             : Optional<ArrayRef<int64_t>>(),
                              type.getElementType());
}
/// Converts to the corresponding builtin tensor type.
///
/// Returns null when the dtype is unknown, since builtin tensors require an
/// element type. Returns an unranked tensor when the sizes are unknown, and a
/// ranked tensor otherwise.
TensorType ValueTensorType::toBuiltinTensor() const {
  if (!hasDtype())
    return nullptr;
  if (!hasSizes())
    return UnrankedTensorType::get(getDtype());
  return RankedTensorType::get(getSizes(), getDtype());
}
ValueTensorType::verify(function_ref<InFlightDiagnostic()> emitError,
Optional<ArrayRef<int64_t>> optionalSizes,
Type optionalDtype) {
return verifyTensorType(emitError, optionalSizes, optionalDtype);
/// Parses the suffix of a !torch.vtensor type via parseTensorType.
Type ValueTensorType::parse(MLIRContext *context, DialectAsmParser &parser) {
  return parseTensorType(
      context, parser,
      [](MLIRContext *context, Optional<ArrayRef<int64_t>> optionalSizes,
         Type optionalType) {
        return ValueTensorType::get(context, optionalSizes, optionalType);
      });
}
/// Prints this type as `vtensor` followed by the shared `<sizes,dtype>` suffix.
void ValueTensorType::print(DialectAsmPrinter &printer) const {
  printer << "vtensor";
  printTensorType(printer, getOptionalSizes(), getOptionalDtype());
}
// Dialect-level verifiers.
/// Verifies Torch-dialect attributes attached to region arguments.
///
/// The only supported attribute is `torch.type_bound`. It must:
///   - be attached to an argument of a FuncOp;
///   - be a TypeAttr holding a !torch.tensor / !torch.vtensor type;
///   - be attached to an argument that is itself of such a type.
/// Any other attribute name is rejected with an error.
LogicalResult TorchDialect::verifyRegionArgAttribute(Operation *op,
                                                     unsigned regionIndex,
                                                     unsigned argIndex,
                                                     NamedAttribute namedAttr) {
  if (namedAttr.first == "torch.type_bound") {
    auto func = dyn_cast<FuncOp>(op);
    if (!func)
      return op->emitError() << "'torch.type_bound' must be attached to a func";
    TypeAttr attr = namedAttr.second.dyn_cast<TypeAttr>();
    if (!attr)
      return op->emitError() << "'torch.type_bound' must be TypeAttr";
    auto type = attr.getValue().dyn_cast<BaseTensorType>();
    if (!type)
      return op->emitError() << "'torch.type_bound' must be of "
                                "!torch.tensor/!torch.vtensor type";
    if (!func.getType().getInput(argIndex).isa<BaseTensorType>())
      return op->emitError() << "'torch.type_bound' must be attached to an "
                                "argument of !torch.tensor/!torch.vtensor type";
    return success();
  }

  return op->emitError() << "unknown region arg attribute '" << namedAttr.first
                         << "'";
}
// Constant materializer.
/// Materializes a constant operation of the given `type` holding `value`
/// (used by the folding infrastructure).
///
/// Supported cases:
///   - !basicpy.BoolType from a 1-bit IntegerAttr -> basicpy.bool_constant
///   - 64-bit IntegerType                         -> std.constant
/// Returns nullptr for anything else.
Operation *TorchDialect::materializeConstant(OpBuilder &builder,
                                             Attribute value, Type type,
                                             Location loc) {
  // Bool (i1 -> !basicpy.BoolType).
  if (type.isa<Basicpy::BoolType>()) {
    auto i1Value = value.dyn_cast<IntegerAttr>();
    if (i1Value && i1Value.getType().getIntOrFloatBitWidth() == 1)
      return builder.create<Basicpy::BoolConstantOp>(loc, type, i1Value);
  }
  // i64 is how we model TorchScript's "scalar integer type" (we could have a
  // proper dedicated Torch integer type in theory). None of our canonicalizers
  // should be creating any other integer type (except perhaps i1 after we
  // resolve that situation). All other integer types live inside tensors (that
  // is, they are never the direct result of an operation, and are thus never
  // candidates for constant materialization).
  if (auto integerType = type.dyn_cast<IntegerType>()) {
    if (integerType.getWidth() == 64)
      return builder.create<ConstantOp>(loc, value);
  }
  return nullptr;
}