//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//

#include "torch-mlir/Conversion/TorchToLinalg/TorchToLinalg.h"

#include "../PassDetail.h"
#include "PopulatePatterns.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Matchers.h"
#include "torch-mlir/Conversion/TorchToLinalg/Utils.h"
#include "torch-mlir/Conversion/Utils/Utils.h"
#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
#include "torch-mlir/Dialect/Torch/Utils/TorchUpstream.h"
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
#include "llvm/ADT/APSInt.h"

using namespace mlir;
using namespace mlir::torch;
using namespace mlir::torch::Torch;

namespace {
// Aten max.dim (min.dim) lowering represents the MaxDimOp (MinDimOp) as a
// linalg.indexed_generic op, producing two output buffers.
//
// The first output buffer contains the maximum (minimum) value found. It is
// initialized to the minimum (maximum) representable value of the input
// element type.
//
// The second output buffer contains the index of the found maximum (minimum)
// value. It is initialized to 0 and has the resulting integer type.
//
// The indexed_generic op updates both the maximum (minimum) value and index
// if the current value is greater (less) than the running max (min).
template class ConvertAtenMinMaxDimOp : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; using OpConversionPattern::getTypeConverter; using OpAdaptor = typename OpTy::Adaptor; LogicalResult matchAndRewrite(OpTy op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { static_assert(std::is_same() || std::is_same()); constexpr bool isMax = std::is_same(); const llvm::StringRef opName = op->getName().getStringRef(); Location loc = op.getLoc(); Value input = adaptor.getSelf(); auto typec = this->getTypeConverter(); auto valResultType = cast(typec->convertType(op.getResult(0).getType())); auto idxResultType = cast(typec->convertType(op.getResult(1).getType())); RankedTensorType inputType = input.getType().template cast(); Type idxElementType = getElementTypeOrSelf(typec->convertType(idxResultType)); if (!idxElementType.isa()) return rewriter.notifyMatchFailure( op, opName + " to linalg.* requires integer-like result type"); bool keepDim = false; if (!matchPattern(op.getKeepdim(), m_TorchConstantBool(&keepDim))) return rewriter.notifyMatchFailure( op, opName + " requires boolean value for keepdim"); int64_t dim; if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim))) return rewriter.notifyMatchFailure( op, opName + " to linalg.* requires int value for Dim"); dim = toPositiveDim(dim, inputType.getRank()); if (!isValidDim(dim, inputType.getRank())) return rewriter.notifyMatchFailure(op, "dim is not a valid dim"); Type inElementType = inputType.getElementType(); bool isUnsigned = false; if (!inElementType.isa()) { if (inElementType.isa()) { auto integerTy = op.getSelf() .getType() .template cast() .getDtype() .template dyn_cast(); isUnsigned = integerTy.isUnsigned(); } else { return rewriter.notifyMatchFailure( op, opName + " to linalg.* requires Float or Integer " "input element type"); } } // Constant op to account for the reduction along dim. 
SmallVector resultShape; for (int64_t i = 0; i < inputType.getRank(); i++) { if (dim != i) { auto currentDimSize = rewriter.create(loc, input, i); resultShape.push_back(currentDimSize); } } // First fill the output buffer for the index. Value filledTensorIdx = createZeroInitTensor(rewriter, loc, resultShape, idxElementType); // Second fill the output buffer for the running max or min. Value initTensorVal = rewriter.create( loc, getAsOpFoldResult(resultShape), inElementType); Value fillValue; if (inElementType.isa()) { fillValue = rewriter.create( loc, rewriter.getFloatAttr( inElementType, APFloat::getInf( inElementType.cast().getFloatSemantics(), /*Negative=*/isMax))); } else if (!isUnsigned) { auto width = inElementType.cast().getWidth(); auto init = isMax ? APSInt::getSignedMinValue(width) : APSInt::getSignedMaxValue(width); fillValue = rewriter.create( loc, rewriter.getIntegerAttr(inElementType, init)); } else if (isUnsigned) { auto width = inElementType.cast().getWidth(); auto init = isMax ? APInt::getMinValue(width) : APInt::getMaxValue(width); fillValue = rewriter.create( loc, rewriter.getIntegerAttr(inElementType, init)); } Value filledTensorVal = rewriter.create(loc, fillValue, initTensorVal).result(); SmallVector iteratorTypes( inputType.getRank(), utils::IteratorType::parallel); iteratorTypes[dim] = utils::IteratorType::reduction; // Create the affine expressions that will be used to // iterate over the input and output tensors. // Here we also set the type of iterator: parallel or reduction. 
SmallVector exprs; SmallVector resultExprs; for (auto size : llvm::enumerate(makeShapeTorchCompatible(inputType.getShape()))) { exprs.push_back(rewriter.getAffineDimExpr(size.index())); if (unsigned(dim) != size.index()) resultExprs.push_back(rewriter.getAffineDimExpr(size.index())); } auto maps = AffineMap::inferFromExprList({exprs, resultExprs, resultExprs}, rewriter.getContext()); auto linalgOp = rewriter.create( loc, ArrayRef({filledTensorVal.getType(), filledTensorIdx.getType()}), input, ValueRange({filledTensorVal, filledTensorIdx}), maps, iteratorTypes, [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange blockArgs) { Value newValue = blockArgs[0]; Value oldValue = blockArgs[1]; Value oldIndex = blockArgs[2]; Value newIndex = rewriter.create( nestedLoc, oldIndex.getType(), rewriter.create(loc, dim)); Value resultVal, predicate; if (inElementType.isa()) { arith::CmpFPredicate predType; if (isMax) { predType = arith::CmpFPredicate::OGT; resultVal = rewriter.create( nestedLoc, newValue, oldValue); } else { predType = arith::CmpFPredicate::OLT; resultVal = rewriter.create( nestedLoc, newValue, oldValue); } predicate = rewriter.create(nestedLoc, predType, newValue, oldValue); } else { arith::CmpIPredicate predType; if (isMax) { predType = isUnsigned ? arith::CmpIPredicate::ugt : arith::CmpIPredicate::sgt; if (isUnsigned) { resultVal = rewriter.create(nestedLoc, newValue, oldValue); } else { resultVal = rewriter.create(nestedLoc, newValue, oldValue); } } else { predType = isUnsigned ? 
arith::CmpIPredicate::ult : arith::CmpIPredicate::slt; if (isUnsigned) { resultVal = rewriter.create(nestedLoc, newValue, oldValue); } else { resultVal = rewriter.create(nestedLoc, newValue, oldValue); } } predicate = rewriter.create(nestedLoc, predType, newValue, oldValue); } auto resultIndex = rewriter.create( nestedLoc, predicate, newIndex, oldIndex); nestedBuilder.create( nestedLoc, ValueRange({resultVal, resultIndex})); }); if (!keepDim) { Value rVal = rewriter.create(loc, valResultType, linalgOp.getResult(0)); Value rIdx = rewriter.create(loc, idxResultType, linalgOp.getResult(1)); llvm::SmallVector res{rVal, rIdx}; rewriter.replaceOp(op, res); return success(); } llvm::SmallVector valShape(valResultType.getShape()); llvm::SmallVector idxShape(idxResultType.getShape()); for (int i = dim, s = valShape.size() - 1; i < s; ++i) { valShape[i] = valShape[i + 1]; idxShape[i] = idxShape[i + 1]; } valShape.resize(valShape.size() - 1); idxShape.resize(idxShape.size() - 1); Value rVal = rewriter.create( loc, valResultType.clone(valShape), linalgOp.getResult(0)); Value rIdx = rewriter.create( loc, idxResultType.clone(idxShape), linalgOp.getResult(1)); SmallVector reassociation(valShape.size()); if (reassociation.size() > 0) { for (int i = 0; i < dim; ++i) reassociation[i].push_back(i); reassociation[std::max(0, dim - 1)].push_back(dim); for (int i = dim, s = reassociation.size(); i < s; ++i) reassociation[i].push_back(i + 1); } valShape.push_back(0); idxShape.push_back(0); for (int i = dim, s = valShape.size() - 1; i < s; ++i) { valShape[i + 1] = valShape[i]; idxShape[i + 1] = idxShape[i]; } valShape[dim] = 1; idxShape[dim] = 1; Value unsqueezeVal = rewriter.create( loc, valResultType, rVal, reassociation); Value unsqueezeIdx = rewriter.create( loc, idxResultType, rIdx, reassociation); llvm::SmallVector unsqueezes = {unsqueezeVal, unsqueezeIdx}; rewriter.replaceOp(op, unsqueezes); return success(); } }; } // namespace static Value createAbsOpForNormOps(OpBuilder &b, 
Location loc, Value elem, Type resultElementType) { if (elem.getType().isa()) { return b.create(loc, elem); } Value self = convertScalarToDtype(b, loc, elem, resultElementType); return b.create(loc, self); } static Value createInitElementForReduceOp(OpBuilder &b, Location loc, Operation *op, Type elementType) { if (isa(op)) return b.create(loc, b.getZeroAttr(elementType)); if (isa(op)) { if (elementType.isa()) return b.create(loc, b.getFloatAttr(elementType, 1.0)); else if (elementType.isa()) return b.create(loc, b.getIntegerAttr(elementType, 1)); } if (isa(op)) { if (elementType.isa()) return b.create( loc, b.getFloatAttr( elementType, APFloat::getInf( elementType.cast().getFloatSemantics(), /*Negative=*/true))); else if (elementType.isa() && elementType.getIntOrFloatBitWidth() != 8) return b.create( loc, b.getIntegerAttr(elementType, APSInt::getSignedMinValue( elementType.getIntOrFloatBitWidth()))); } if (isa(op)) { if (elementType.isa()) return b.create( loc, b.getFloatAttr( elementType, APFloat::getInf( elementType.cast().getFloatSemantics(), /*Negative=*/false))); else if (elementType.isa() && elementType.getIntOrFloatBitWidth() != 8) return b.create( loc, b.getIntegerAttr(elementType, APSInt::getSignedMaxValue( elementType.getIntOrFloatBitWidth()))); } if (isa(op) || isa(op) || isa(op)) return b.create(loc, b.getZeroAttr(elementType)); if (isa(op)) { return b.create(loc, b.getBoolAttr(true)); } op->emitError("unimplemented lowering in createInitElementForReduceOp"); return nullptr; } static Value createLinalgPayloadForReduceOp(OpBuilder &b, Location loc, ValueRange payloadArgs, Operation *op, ArrayRef operands, Type resultElementType) { if (isa(op)) { Value self = convertScalarToDtype(b, loc, payloadArgs[0], resultElementType); Value result = payloadArgs[1]; if (resultElementType.isa()) return b.create(loc, self, result); else if (resultElementType.isa()) return b.create(loc, self, result); } else if (isa(op)) { Value self = convertScalarToDtype(b, loc, 
payloadArgs[0], resultElementType); Value result = payloadArgs[1]; if (resultElementType.isa()) return b.create(loc, self, result); else if (resultElementType.isa()) return b.create(loc, self, result); } else if (auto max = dyn_cast(op)) { Value self = convertScalarToDtype(b, loc, payloadArgs[0], resultElementType); Value result = payloadArgs[1]; if (resultElementType.isa()) return b.create(loc, self, result); else if (resultElementType.isa()) { IntegerType intType = max.getSelf() .getType() .cast() .getDtype() .dyn_cast(); if (intType.isUnsigned()) return b.create(loc, self, result); if (intType.isSigned()) return b.create(loc, self, result); } } else if (auto min = dyn_cast(op)) { Value self = convertScalarToDtype(b, loc, payloadArgs[0], resultElementType); Value result = payloadArgs[1]; if (resultElementType.isa()) return b.create(loc, self, result); else if (resultElementType.isa()) { IntegerType intType = min.getSelf() .getType() .cast() .getDtype() .dyn_cast(); if (intType.isUnsigned()) return b.create(loc, self, result); if (intType.isSigned()) return b.create(loc, self, result); } } else if (isa(op)) { // This creates payload for only the first of the two linalg.generic ops. // TODO: Short-circuit operations if `p` is zero or one. Value elem = payloadArgs[0]; Value result = payloadArgs[1]; AtenNormScalarOp::Adaptor adaptor(operands); Value p = convertScalarToDtype(b, loc, adaptor.getP(), resultElementType); auto abs = createAbsOpForNormOps(b, loc, elem, resultElementType); auto pow = b.create(loc, abs, p); return b.create(loc, pow, result); } else if (isa(op)) { // This creates payload for only the first of the two linalg.generic ops. // TODO: Short-circuit operations if `ord` is zero or one. 
Value elem = payloadArgs[0]; Value result = payloadArgs[1]; AtenLinalgVectorNormOp::Adaptor adaptor(operands); Value ord = convertScalarToDtype(b, loc, adaptor.getOrd(), resultElementType); auto abs = createAbsOpForNormOps(b, loc, elem, resultElementType); auto pow = b.create(loc, abs, ord); return b.create(loc, pow, result); } else if (isa(op)) { Value elem = payloadArgs[0]; Value result = payloadArgs[1]; TypedAttr twoAttr = b.getFloatAttr(resultElementType, 2.0); auto ord = b.create(loc, twoAttr); auto abs = createAbsOpForNormOps(b, loc, elem, resultElementType); auto pow = b.create(loc, abs, ord); return b.create(loc, pow, result); } else if (isa(op)) { Value elem = payloadArgs[0]; Value result = payloadArgs[1]; Value self = convertScalarToDtype(b, loc, elem, resultElementType); return b.create(loc, self, result); } op->emitError("unimplemented lowering in createLinalgPayloadForReduceOp"); return nullptr; } namespace { class ConvertReductionOp : public ConversionPattern { private: /// Given a reduction operation that has the `keepdim` attribute and the /// (optional) `dim` attribute, return the source tensor operand and the /// literal values of the attributes or failure otherwise. template FailureOr computeReductionOpInfoForDimVariantOp( T op, ArrayRef operands, ConversionPatternRewriter &rewriter) const { auto opInfo = torch_to_linalg::ReductionOpInfo{false, Value{}, {}}; typename T::Adaptor adaptor(operands); opInfo.tensorOperand = adaptor.getSelf(); auto inputType = opInfo.tensorOperand.getType().cast(); if (!matchPattern(op.getKeepdim(), m_TorchConstantBool(&opInfo.keepDim))) return rewriter.notifyMatchFailure(op, "`keepdim` must be a constant bool"); SmallVector dimList; int64_t dim; bool isNoneOrEmptyDimList = op.getDim().getType().template isa(); if (matchPattern(op.getDim(), m_TorchListOfConstantInts(dimList))) { // Fix negative dimensions, if any, before adding to the list. 
for (int64_t dim : dimList) { dim = toPositiveDim(dim, inputType.getRank()); // Drop invalid dimensions if (isValidDim(dim, inputType.getRank())) opInfo.dimSet.insert(dim); } if (dimList.empty()) isNoneOrEmptyDimList = true; } else if (matchPattern(op.getDim(), m_TorchConstantInt(&dim))) { dim = toPositiveDim(dim, inputType.getRank()); if (!isValidDim(dim, inputType.getRank())) return rewriter.notifyMatchFailure( op, "`dim` argument must be valid, invalid received."); opInfo.dimSet.insert(dim); } else if (!isNoneOrEmptyDimList) { return rewriter.notifyMatchFailure( op, "`dim` argument must be a constant int list or None"); } if (isNoneOrEmptyDimList) { // If no dimensions were specified, reduce along all dimensions for (int64_t i = 0; i < inputType.getRank(); i++) opInfo.dimSet.insert(i); } return opInfo; } /// Given a reduction operation, return the source tensor operand and the /// literal values of the `keepdim` and `dim` attributes, if any, or failure /// otherwise. FailureOr computeReductionOpInfo(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const { auto opInfo = torch_to_linalg::ReductionOpInfo{false, Value{}, {}}; if (isa(op)) { opInfo.tensorOperand = operands[0]; auto inputType = opInfo.tensorOperand.getType().cast(); // `AtenSumOp`, `AtenMaxOp`, and `AtenMinOp` each reduce along all the // dimensions of the input tensor. 
for (int64_t i = 0; i < inputType.getRank(); i++) opInfo.dimSet.insert(i); return opInfo; } if (auto sumOp = dyn_cast(op)) return computeReductionOpInfoForDimVariantOp(sumOp, operands, rewriter); if (auto prodOp = dyn_cast(op)) return computeReductionOpInfoForDimVariantOp(prodOp, operands, rewriter); if (auto normOp = dyn_cast(op)) return computeReductionOpInfoForDimVariantOp(normOp, operands, rewriter); if (auto normOp = dyn_cast(op)) return computeReductionOpInfoForDimVariantOp(normOp, operands, rewriter); if (auto allOp = dyn_cast(op)) return computeReductionOpInfoForDimVariantOp(allOp, operands, rewriter); return rewriter.notifyMatchFailure(op, "not a supported reduce op"); } /// Generate a linalg.generic operation for pointwise exponentiation of each /// element. Value createElementwiseExp(Location loc, Type elemType, Value exponent, Value inputTensor, const torch_to_linalg::ReductionOpInfo &opInfo, ConversionPatternRewriter &rewriter) const { bool err = false; auto powBodyBuilder = [&](OpBuilder &builder, Location loc, ValueRange payloadArgs) { Value elem = convertScalarToDtype(builder, loc, payloadArgs[0], elemType); auto result = builder.create(loc, elem, exponent); if (result) builder.create(loc, Value{result}); err = !result; }; Value powOp = torch_to_linalg::createElementwiseLinalgGeneric( rewriter, loc, {inputTensor}, elemType, powBodyBuilder); return err ? Value{} : powOp; } template FailureOr createSecondReductionForNormOp(Location loc, Type elemType, TOp op, Value ordOp, Value firstReduction, const torch_to_linalg::ReductionOpInfo &opInfo, ConversionPatternRewriter &rewriter) const { // Cast `ord` to float so that we can readily pass it math.powf. Value ordValue = convertScalarToDtype(rewriter, loc, ordOp, elemType); // TODO: Add support for ord = {0, +inf, -inf}. 
auto epsilon = 1e-5; auto ordLiteral = 0.0; if (matchPattern(ordValue, m_TorchConstantFloat(&ordLiteral)) && fabs(ordLiteral) < epsilon) return rewriter.notifyMatchFailure(op, "unimplemented: L0 norm"); if (std::isinf(ordLiteral)) return rewriter.notifyMatchFailure(op, "unimplemented: ord = +/- inf"); // Raise each summed value to the inverse of the order of the norm. TypedAttr oneAttr = rewriter.getFloatAttr(elemType, 1.0); auto oneValue = rewriter.create(loc, oneAttr); auto inverseOrdValue = rewriter.create(loc, oneValue, ordValue); // Use the results of the first reduction operation from above to generate // a second reduction operation. Value reduceOp = createElementwiseExp(loc, elemType, inverseOrdValue, firstReduction, opInfo, rewriter); if (!reduceOp) return rewriter.notifyMatchFailure( op, "failed to create linalg.generic operation for element-wise " "exponentiation"); return reduceOp; } /// Generate a linalg.generic operation for a reduction. Value createReductionOp(Location loc, Type elemType, Operation *op, ArrayRef operands, const torch_to_linalg::ReductionOpInfo &opInfo, ConversionPatternRewriter &rewriter) const { bool err = false; auto reductionBodyBuilder = [&](OpBuilder &builder, Location loc, ValueRange payloadArgs) { Value result = createLinalgPayloadForReduceOp(builder, loc, payloadArgs, op, operands, elemType); if (result) builder.create(loc, result); err = !result; }; Value initElem = createInitElementForReduceOp(rewriter, loc, op, elemType); Value reduceOp = torch_to_linalg::createReductionLinalgGeneric( rewriter, loc, opInfo, initElem, reductionBodyBuilder); return err ? Value{} : reduceOp; } /// Depending on the operation, check validity of the result's element type. 
LogicalResult validateReductionElementType(Operation *op, Type elemType, ConversionPatternRewriter &rewriter) const { if ((isa(op) || isa(op) || isa(op)) && !elemType.isa()) return rewriter.notifyMatchFailure( op, "only float types are valid for vector norm ops"); if (isa(op) && elemType.isa() && elemType.getIntOrFloatBitWidth() == 8) return rewriter.notifyMatchFailure(op, "uint8 is not supported"); // No checks for all other reduction operations return success(); } public: ConvertReductionOp(TypeConverter &typeConverter, MLIRContext *context) : ConversionPattern(typeConverter, MatchAnyOpTypeTag(), /*benefit=*/1, context) {} LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { if (failed(verifyLinalgCompatibleTypes(op, rewriter))) return rewriter.notifyMatchFailure( op, "invalid operand or result types to use with linalg on tensors"); FailureOr opInfo = computeReductionOpInfo(op, operands, rewriter); if (failed(opInfo)) return opInfo; Location loc = op->getLoc(); auto resultType = getTypeConverter() ->convertType(op->getResult(0).getType()) .cast(); Type elemType = resultType.getElementType(); LogicalResult elemTypeCheck = validateReductionElementType(op, elemType, rewriter); if (failed(elemTypeCheck)) return elemTypeCheck; Value reduceOp = createReductionOp(loc, elemType, op, operands, *opInfo, rewriter); if (!reduceOp) return rewriter.notifyMatchFailure( op, "failed to create linalg.generic operation for reduction"); // If this is aten.norm.Scalar op, then we need to generate another // linalg.generic op that references the first linalg.generic op. 
if (isa(op)) { AtenNormScalarOp::Adaptor adaptor(operands); FailureOr secondReduceOp = createSecondReductionForNormOp( loc, elemType, op, adaptor.getP(), reduceOp, *opInfo, rewriter); if (failed(secondReduceOp)) return secondReduceOp; reduceOp = *secondReduceOp; } // If this is aten.linalg_vector_norm op, then we need to generate another // linalg.generic op that references the first linalg.generic op. if (auto normOp = dyn_cast(op)) { AtenLinalgVectorNormOp::Adaptor adaptor(operands); FailureOr secondReduceOp = createSecondReductionForNormOp( loc, elemType, normOp, adaptor.getOrd(), reduceOp, *opInfo, rewriter); if (failed(secondReduceOp)) return secondReduceOp; reduceOp = *secondReduceOp; } // If it is aten.frobenius_norm.dim op, take the square root of reduceOp as // the final result if (auto normOp = dyn_cast(op)) { auto halfAttr = rewriter.getFloatAttr(elemType, 0.5); auto exp = rewriter.create(loc, halfAttr); reduceOp = createElementwiseExp(loc, elemType, exp, reduceOp, *opInfo, rewriter); } rewriter.replaceOpWithNewOp(op, resultType, reduceOp); return success(); } }; } // namespace void mlir::torch::torch_to_linalg::populateReductionPatternsAndLegality( TypeConverter &typeConverter, RewritePatternSet &patterns, ConversionTarget &target) { MLIRContext *context = patterns.getContext(); target.addIllegalOp(); patterns.add>(typeConverter, context); target.addIllegalOp(); patterns.add>(typeConverter, context); target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); patterns.add(typeConverter, context); }