//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//

#include "torch-mlir/Conversion/Utils/Utils.h"

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Transforms/DialectConversion.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"

namespace mlir {
namespace torch {
namespace Torch {

LogicalResult verifyLinalgCompatibleTypes(Operation *op,
                                          PatternRewriter &rewriter) {
  // Check the value tensor is ranked as expected by Linalg.
  // TODO: Remove this check but use a separate verification pass to verify the
  // invariants expected by later passes.
  auto isValidLinalgType = [](Type type) {
    if (isa<NonValueTensorType>(type))
      return false;
    auto tensor = dyn_cast<ValueTensorType>(type);
    return !tensor ||
           tensor.toBuiltinTensor().dyn_cast_or_null<RankedTensorType>();
  };
  bool valid = llvm::all_of(op->getOperandTypes(), isValidLinalgType) &&
               llvm::all_of(op->getResultTypes(), isValidLinalgType);
  if (!valid)
    return rewriter.notifyMatchFailure(op, "type cannot be lowered to linalg");
  return success();
}

LogicalResult checkNotNone(PatternRewriter &rewriter, Operation *op, Value v) {
  Type type = v.getType();
  if (isa<Torch::OptionalType>(type) || isa<Torch::NoneType>(type) ||
      isa<mlir::NoneType>(type))
    return rewriter.notifyMatchFailure(op, "unimplemented None type arg");
  return success();
}

// Generate IR: dim = dim >= 0 ? dim : dim + inputRank
Value toPositiveDimDynamic(OpBuilder &b, Location loc, Value dim,
                           Value inputRank) {
  assert(isa<mlir::IntegerType>(dim.getType()) &&
         "dim arg of toPositiveDim must be integer type");
  Value dimAddInputRank = b.create<arith::AddIOp>(loc, dim, inputRank);
  Value cst0 =
      b.create<arith::ConstantOp>(loc, b.getZeroAttr(inputRank.getType()));
  Value predDimGEZero =
      b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sge, dim, cst0);
  Value dimInt =
      b.create<arith::SelectOp>(loc, predDimGEZero, dim, dimAddInputRank);
  return dimInt;
}

// Generate IR: assert(dim >= 0 && dim < inputRank)
void assertIsValidDim(OpBuilder &b, Location loc, Value dim, Value inputRank) {
  assert(dim.getType().isa<mlir::IntegerType>() &&
         "dim arg of assertIsValidDim must be integer type");
  Value cst0 =
      b.create<arith::ConstantOp>(loc, b.getZeroAttr(inputRank.getType()));
  Value predGEZero =
      b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sge, dim, cst0);
  b.create<cf::AssertOp>(
      loc, predGEZero, b.getStringAttr("dim must be greater or equal to zero"));
  Value predLTInputRank =
      b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, dim, inputRank);
  b.create<cf::AssertOp>(loc, predLTInputRank,
                         b.getStringAttr("dim must be smaller than inputRank"));
}
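// As an illustration (not part of the helper itself), toPositiveDimDynamic
// emits IR roughly of the following shape for an i64 `dim` and `inputRank`
// (a sketch; the SSA value names are hypothetical):
//
//   %add = arith.addi %dim, %inputRank : i64
//   %c0 = arith.constant 0 : i64
//   %sge = arith.cmpi sge, %dim, %c0 : i64
//   %pos = arith.select %sge, %dim, %add : i64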
// Hack to deal with Torch list type arguments, which are not supported end to
// end. Constant list values can be extracted directly; non-constant list
// values are not supported.
// TODO: loosen this constraint once list types are properly supported.
bool isConstantIntListMatching(Value value, SmallVectorImpl<int64_t> &expects) {
  SmallVector<int64_t> intValues;
  if (!matchPattern(value, m_TorchListOfConstantInts(intValues)))
    return false;

  if (intValues.size() != expects.size())
    return false;

  for (auto it : llvm::zip(intValues, expects)) {
    if (std::get<0>(it) != std::get<1>(it))
      return false;
  }
  return true;
}

void checkDimEqualHelper(OpBuilder &b, Location loc, Value lhsDim,
                         Value rhsDim) {
  Type lhsType = lhsDim.getType();
  Type rhsType = rhsDim.getType();
  auto checkIntOrIndex = [](Type type) {
    assert((isa<IntegerType>(type) || isa<IndexType>(type)) &&
           "must be either integer or index type");
  };
  checkIntOrIndex(lhsType);
  checkIntOrIndex(rhsType);
  Value lhsDimInt =
      lhsType.isIndex() ? castIndexToInt64(b, loc, lhsDim) : lhsDim;
  Value rhsDimInt =
      rhsType.isIndex() ? castIndexToInt64(b, loc, rhsDim) : rhsDim;
  Value contractingDimEqual = b.create<arith::CmpIOp>(
      loc, arith::CmpIPredicate::eq, lhsDimInt, rhsDimInt);
  b.create<cf::AssertOp>(loc, contractingDimEqual,
                         b.getStringAttr("mismatching contracting dimension"));
}

// Creates a tensor with the required `sizes` and `elemTy`, filled with
// `initElem`.
Value createInitTensor(OpBuilder &b, Location loc, ValueRange sizes,
                       Type elemTy, Value initElem) {
  Value initTensor =
      b.create<tensor::EmptyOp>(loc, getAsOpFoldResult(sizes), elemTy);
  return b.create<linalg::FillOp>(loc, initElem, initTensor).getResult(0);
}

Value createZeroInitTensor(OpBuilder &b, Location loc, ValueRange sizes,
                           Type elemTy) {
  Value initTensor =
      b.create<tensor::EmptyOp>(loc, getAsOpFoldResult(sizes), elemTy);
  RankedTensorType type = cast<RankedTensorType>(initTensor.getType());
  Value c0 =
      b.create<arith::ConstantOp>(loc, b.getZeroAttr(type.getElementType()));
  return b.create<linalg::FillOp>(loc, c0, initTensor).getResult(0);
}

Value castIntToIndex(OpBuilder &b, Location loc, Value v) {
  assert(v.getType().isa<mlir::IntegerType>() &&
         "must be called with integer type");
  return b.create<arith::IndexCastOp>(loc, b.getIndexType(), v);
}

Value castIndexToInt64(OpBuilder &b, Location loc, Value idx) {
  assert(idx.getType().isa<mlir::IndexType>() &&
         "must be called with index type");
  return b.create<arith::IndexCastOp>(loc, b.getI64Type(), idx);
}

SmallVector<Value>
castIntVectorToIndexVector(OpBuilder &b, Location loc,
                           SmallVectorImpl<Value> &intValues) {
  SmallVector<Value> indexValues;
  for (Value v : intValues)
    indexValues.push_back(castIntToIndex(b, loc, v));
  return indexValues;
}

SmallVector<Value>
castIndexVectorToInt64Vector(OpBuilder &b, Location loc,
                             SmallVectorImpl<Value> &indexValues) {
  SmallVector<Value> intValues;
  for (Value v : indexValues)
    intValues.push_back(castIndexToInt64(b, loc, v));
  return intValues;
}

Value getDimOp(OpBuilder &b, Location loc, Value v, int dim) {
  return b.createOrFold<tensor::DimOp>(loc, v, dim);
}

SmallVector<Value> getTensorSizesUntilDim(OpBuilder &b, Location loc,
                                          Value tensor, int dim) {
  RankedTensorType type = cast<RankedTensorType>(tensor.getType());
  assert(dim < type.getRank() &&
         "The given dim must be smaller than tensor rank");
  (void)type;
  SmallVector<Value> sizes;
  for (int i = 0; i <= dim; i++)
    sizes.push_back(getDimOp(b, loc, tensor, i));
  return sizes;
}

SmallVector<Value> getTensorSizes(OpBuilder &b, Location loc, Value tensor) {
  RankedTensorType type = cast<RankedTensorType>(tensor.getType());
  return getTensorSizesUntilDim(b, loc, tensor, type.getRank() - 1);
}

Value getTensorSize(OpBuilder &b, Location loc, Value tensor) {
  SmallVector<Value> sizes(getTensorSizes(b, loc, tensor));
  Value productResult = b.create<arith::ConstantOp>(loc, b.getIndexAttr(1));
  for (Value size : sizes)
    productResult = b.create<arith::MulIOp>(loc, productResult, size);
  return castIndexToInt64(b, loc, productResult);
}
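// Usage sketch (hypothetical caller code, not part of this file's API): a
// lowering pattern can combine the helpers above to build a zero-filled f32
// accumulator matching the dynamic shape of an input tensor:
//
//   SmallVector<Value> sizes = getTensorSizes(rewriter, loc, input);
//   Value acc =
//       createZeroInitTensor(rewriter, loc, sizes, rewriter.getF32Type());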
// Creates a constant of type `elemType` with value `val`.
Value getConstant(OpBuilder &b, Location loc, int64_t val, Type elemType) {
  TypedAttr attr = {};
  if (isa<mlir::FloatType>(elemType))
    attr = b.getFloatAttr(elemType, val);
  if (isa<mlir::IndexType>(elemType))
    attr = b.getIndexAttr(val);
  if (isa<mlir::IntegerType>(elemType))
    attr = b.getIntegerAttr(elemType,
                            APInt(cast<IntegerType>(elemType).getWidth(), val));
  if (!attr)
    return nullptr;
  return b.create<arith::ConstantOp>(loc, elemType, attr);
}

SmallVector<Value> getAsConstantIntValues(OpBuilder &b, Location loc,
                                          SmallVectorImpl<int64_t> &ints) {
  return llvm::to_vector<4>(llvm::map_range(ints, [&](int64_t val) -> Value {
    return b.create<arith::ConstantOp>(loc,
                                       b.getIntegerAttr(b.getI64Type(), val));
  }));
}

SmallVector<Value> getAsConstantIndexValues(OpBuilder &b, Location loc,
                                            SmallVectorImpl<int64_t> &ints) {
  return llvm::to_vector<4>(llvm::map_range(ints, [&](int64_t val) -> Value {
    return b.create<arith::ConstantOp>(loc, b.getIndexAttr(val));
  }));
}

// This is a temporary solution to deal with types that are not fully
// supported, like list and dict. For those container types, this helper can
// be used to convert their elements to the valid target type.
// TODO: remove this when list gets full support.
SmallVector<Value> getTypeConvertedValues(OpBuilder &b, Location loc,
                                          const TypeConverter *converter,
                                          SmallVectorImpl<Value> &vs) {
  return llvm::to_vector<4>(llvm::map_range(vs, [&](Value v) {
    return converter->materializeTargetConversion(
        b, loc, converter->convertType(v.getType()), v);
  }));
}

mlir::RankedTensorType GetTypeFromTensorShape(llvm::ArrayRef<int64_t> shape,
                                              mlir::Type elementType,
                                              mlir::Attribute encoding) {
  return mlir::RankedTensorType::get(makeShapeLLVMCompatible(shape),
                                     elementType, encoding);
}

static std::optional<int64_t> getIntegerValue(Value scalar) {
  if (auto constOp = scalar.getDefiningOp<Torch::ConstantIntOp>()) {
    return std::optional<int64_t>(constOp.getValue());
  }
  return std::optional<int64_t>();
}
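// Usage sketch (hypothetical caller code): converting the elements of a
// !torch.list<int> operand with getTypeConvertedValues, then turning them
// into index values for shape computations:
//
//   SmallVector<Value> elems;
//   getListConstructElements(listOperand, elems);
//   elems = getTypeConvertedValues(rewriter, loc, getTypeConverter(), elems);
//   SmallVector<Value> idx = castIntVectorToIndexVector(rewriter, loc, elems);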
// Converts a scalar value to the target type. The scalar value can be an
// element from a tensor or a scalar in the PyTorch dialect. Both the scalar
// and the dtype should already be converted builtin types.
Value convertScalarToDtype(OpBuilder &b, Location loc, Value scalar, Type dtype,
                           std::optional<Type> srcOriginalDtype,
                           std::optional<Type> dstOriginalDtype,
                           std::optional<Value> originalScalar) {
  Type scalarType = scalar.getType();
  if (scalarType == dtype)
    return scalar;

  auto isByteOrChar = [](Type type) {
    if (auto integerTy = dyn_cast<mlir::IntegerType>(type)) {
      return integerTy.getWidth() == 8;
    }
    return false;
  };

  // We support conversion to byte dtype only if the original scalar is an
  // integer constant with a value in the range [0, 63].
  if (isByteOrChar(dtype)) {
    if (!dstOriginalDtype.has_value()) {
      mlir::emitError(loc)
          << "unimplemented: for conversion to byte or char type "
             "dstOriginalDtype has to be passed to convertScalarToDtype";
      return nullptr;
    }
    if (dstOriginalDtype->isUnsignedInteger()) {
      if (originalScalar.has_value()) {
        std::optional<int64_t> optConstVal =
            getIntegerValue(originalScalar.value());
        if (optConstVal.has_value()) {
          int64_t constVal = optConstVal.value();
          if (constVal < 0 || constVal > 63) {
            // Do the conversion only if the original integer value is in the
            // range [0, 63].
            mlir::emitError(loc)
                << "unsupported: conversion to byte type for "
                   "convertScalarToDtype "
                << scalarType << "(scalar type) -> " << dtype << "(dtype)";
            return nullptr;
          }
        }
      }
    }
  }

  // If the dtype is i1, i.e., a boolean type.
  if (dtype.isSignlessInteger(1)) {
    Type scalarType = scalar.getType();
    Value cstZero = b.create<arith::ConstantOp>(loc, b.getZeroAttr(scalarType));
    if (isa<mlir::FloatType>(scalarType)) {
      return b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UNE, scalar,
                                     cstZero);
    } else if (isa<mlir::IntegerType>(scalarType)) {
      return b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ne, scalar,
                                     cstZero);
    } else {
      mlir::emitError(loc)
          << "unsupported scalar type for convertScalarToDtype " << scalarType
          << "(scalar type) -> " << dtype << "(dtype)";
      return nullptr;
    }
  }

  if (auto dtypeFloat = dyn_cast<mlir::FloatType>(dtype)) {
    if (auto scalarFloat = dyn_cast<mlir::FloatType>(scalarType)) {
      if (scalarFloat.getWidth() > dtypeFloat.getWidth())
        return b.create<arith::TruncFOp>(loc, dtype, scalar);
      // Only scalarFloat width < dtypeFloat width can reach here.
      return b.create<arith::ExtFOp>(loc, dtype, scalar);
    }
    assert(isa<mlir::IntegerType>(scalarType));
    if (scalarType.isSignlessInteger(1) ||
        (srcOriginalDtype.has_value() && srcOriginalDtype->isUnsignedInteger()))
      return b.create<arith::UIToFPOp>(loc, dtype, scalar);
    // It's safe to use SIToFPOp because ui8/si8 are the only ones where
    // unsigned handling is needed, and we checked for that case above.
    return b.create<arith::SIToFPOp>(loc, dtype, scalar);
  }

  if (auto dtypeInteger = dyn_cast<mlir::IntegerType>(dtype)) {
    if (auto scalarFloat = dyn_cast<mlir::FloatType>(scalarType))
      return b.create<arith::FPToSIOp>(loc, dtype, scalar);
    assert(isa<mlir::IntegerType>(scalarType));
    auto scalarInteger = cast<mlir::IntegerType>(scalarType);
    if (scalarInteger.getWidth() > dtypeInteger.getWidth())
      return b.create<arith::TruncIOp>(loc, dtype, scalar);
    if (scalarType.isSignlessInteger(1) ||
        (srcOriginalDtype.has_value() && srcOriginalDtype->isUnsignedInteger()))
      return b.create<arith::ExtUIOp>(loc, dtype, scalar);
    // Only scalarInteger width < dtypeInteger width can reach here.
    // It's safe to use ExtSIOp here because ui8/si8 are the only ones where
    // unsigned handling is needed, and we checked for that case above.
    return b.create<arith::ExtSIOp>(loc, dtype, scalar);
  }

  if (auto dtypeComplex = dyn_cast<mlir::ComplexType>(dtype)) {
    if (auto scalarComplex = dyn_cast<mlir::ComplexType>(scalarType)) {
      auto dtypeElemType = dtypeComplex.getElementType();

      // Extract the real and imaginary parts of the scalar, cast them to the
      // target element type, and create a new complex value with the target
      // complex type.
      Value realVal = b.create<complex::ReOp>(loc, scalar);
      Value imgVal = b.create<complex::ImOp>(loc, scalar);

      realVal = convertScalarToDtype(b, loc, realVal, dtypeElemType);
      imgVal = convertScalarToDtype(b, loc, imgVal, dtypeElemType);

      return b.create<complex::CreateOp>(loc, dtypeComplex, realVal, imgVal);
    }
    mlir::emitError(loc) << "unsupported scalar type for convertScalarToDtype "
                         << scalarType << "(scalar type) -> " << dtype
                         << "(dtype)";
  }

  llvm_unreachable("convertScalarToDtype should handle all the types");
}
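// For example (an illustrative mapping, not an exhaustive list), the cast op
// convertScalarToDtype selects is:
//   f64 -> f32 : arith.truncf      f32 -> f64 : arith.extf
//   i64 -> f32 : arith.sitofp      (arith.uitofp if srcOriginalDtype is
//                                   unsigned or the source is i1)
//   f32 -> i64 : arith.fptosi      i64 -> i32 : arith.trunci
//   any numeric -> i1 : a "!= 0" comparison (arith.cmpf une / arith.cmpi ne)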
Value toPositiveValidDim(ConversionPatternRewriter &rewriter, Location loc,
                         Value torchOptionalInt, Value builtinInt,
                         Value defaultValue, Value dimSize) {
  if (torchOptionalInt.getType().isa<Torch::NoneType>())
    return defaultValue;
  auto dimSizeAsInt = castIndexToInt64(rewriter, loc, dimSize);
  Value positiveDim =
      toPositiveDimDynamic(rewriter, loc, builtinInt, dimSizeAsInt);
  // positiveDim < 0 ? 0 : positiveDim
  Value cst0 = rewriter.create<arith::ConstantOp>(
      loc, rewriter.getZeroAttr(dimSizeAsInt.getType()));
  Value predDimSltZero = rewriter.create<arith::CmpIOp>(
      loc, arith::CmpIPredicate::slt, positiveDim, cst0);
  Value atLeastZero =
      rewriter.create<arith::SelectOp>(loc, predDimSltZero, cst0, positiveDim);
  // atLeastZero > dimSizeAsInt ? dimSizeAsInt : atLeastZero
  Value sgtDimSize = rewriter.create<arith::CmpIOp>(
      loc, arith::CmpIPredicate::sgt, atLeastZero, dimSizeAsInt);
  Value boundedByDimSize = rewriter.create<arith::SelectOp>(
      loc, sgtDimSize, dimSizeAsInt, atLeastZero);
  return castIntToIndex(rewriter, loc, boundedByDimSize);
}

} // namespace Torch
} // namespace torch
} // namespace mlir