//===----------------------------------------------------------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // Also available under a BSD-style license. See LICENSE. // //===----------------------------------------------------------------------===// #include "PassDetail.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "torch-mlir/Dialect/Torch/IR/TorchOps.h" #include "torch-mlir/Dialect/Torch/Transforms/Passes.h" #include "torch-mlir/Dialect/Torch/Utils/Utils.h" #include using namespace mlir; using namespace mlir::torch; using namespace mlir::torch::Torch; namespace { template struct QuantInfo { static constexpr unsigned operandsToQuantize[2] = {0, 1}; }; template <> struct QuantInfo { static constexpr unsigned operandsToQuantize[1] = {0}; }; // A QCommutingOp is an Op satisfying: // 1. Has at most one tensor operand at index 0 // 2. Has a single output, which is a tensor // 3. Satisfies the commutation relation: // [MPTQT -> Dequant -> Op(float)] = [Op(int) -> MPTQT -> Dequant] // where MPTQT = "Aten_MakePerTensorQuantizedTensorOp" // and Dequant = "AtenDequantizeSelfOp" or "AtenDequantizeTensorOp" bool isQCommutingOp(mlir::Operation *op) { // if adding a new commuting op here, be sure to add a // RemoveUnused pattern for that op to clean up afterwards return llvm::isa( op); } // The following conversion takes patterns of the form [op0 -> MPTQT -> dequant // -> Op1 -> Op2 -> ... Opk -> SrcOp] to [op0 -> Int(Op1) -> Int(Op2) -> ... -> // Int(Opk) -> MPTQT -> SrcOp] for any sequence of q commuting ops // {Op1,Op2,...,Opk} with k <= depth. // With depth = 0, this conversion will simply fuse any immediately quantizable // operands: [MPTQT -> Dequant -> SrcOp (float operands)] to [MPTQT -> SrcOp(int // operands)] template class QuantizeOperandsPastCommutingOps : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(SrcOp op, PatternRewriter &rewriter) const override { mlir::Location loc = op.getLoc(); llvm::SmallVector operands(op->getOperands()); bool dequanted = false; for (unsigned i : QuantInfo::operandsToQuantize) { Value operand = operands[i]; std::stack commutingOpStack; Value dequantOpd, MPTQTOpd, scale, zeroPoint; for (unsigned k = 0; k < depth + 1; k++) { auto currOp = operand.getDefiningOp(); // Case 0 : currOp is a nullptr (e.g., operand is a block argument) if (!currOp) break; // Case 1 : currOp is a q commuting op (continue loop) if (isQCommutingOp(currOp)) { commutingOpStack.push(currOp); // set operand to currOp for next k-iteration operand = currOp->getOperand(0); continue; } // Case 2 : currOp is a dequant op (end loop) if (llvm::isa(currOp)) { dequantOpd = currOp->getOperand(0); auto MPTQTOp = dequantOpd.getDefiningOp(); MPTQTOpd = MPTQTOp.getOperand(0); scale = MPTQTOp.getOperand(1); zeroPoint = MPTQTOp.getOperand(2); } // either a dequant was found or chain broken, so break loop break; } // move to next operand if this trace was unsuccessful if (!MPTQTOpd) continue; // a successful trace occured, so set dequant to true dequanted = true; // rewrite stack Value oldOpd = MPTQTOpd; Type intDType = cast(MPTQTOpd.getType()).getOptionalDtype(); while (!commutingOpStack.empty()) { // get front of the commuting op stack and replace its first operand // with oldOpd auto currOp = commutingOpStack.top(); commutingOpStack.pop(); llvm::SmallVector currOperands(currOp->getOperands()); currOperands[0] = oldOpd; // pad ops aren't quite commuting, so we include some extra logic to // quantize the padding value if (isa(currOp)) { Value floatPadValue = currOperands.back(); Value quantPadValue; if (isa(floatPadValue.getType())) quantPadValue = rewriter.create(loc, zeroPoint); else { floatPadValue = rewriter.create(loc, floatPadValue); quantPadValue = rewriter.create( loc, floatPadValue, scale); quantPadValue = rewriter.create( loc, quantPadValue, zeroPoint); } // clamp pad value to qint range if (auto intType = dyn_cast(intDType)) { bool isSigned = intType.isSignedInteger(); int64_t width = intType.getWidth(); assert(width < 64 && "quantized int bitwidth should be less than 64"); int64_t minInt = isSigned ? -(1 << (width - 1)) : 0; int64_t maxInt = isSigned ? -minInt - 1 : ((1 << width) - 1); Value minQValueFloat = rewriter.create( loc, rewriter.getF64FloatAttr(minInt)); Value maxQValueFloat = rewriter.create( loc, rewriter.getF64FloatAttr(maxInt)); SmallVector emptyShape; auto floatTensorType = rewriter.getType( emptyShape, rewriter.getF64Type()); Value quantPadValueTensor = createRank0Tensor( rewriter, loc, floatTensorType, quantPadValue); Value clampedTensor = rewriter.create( loc, floatTensorType, quantPadValueTensor, minQValueFloat, maxQValueFloat); quantPadValue = rewriter.create( loc, rewriter.getType(), clampedTensor); } // quantPadValue is a float, but will get converted/truncated currOperands.back() = quantPadValue; } // get new result type auto oldType = cast(currOp->getResultTypes()[0]); auto intType = rewriter.getType(oldType.getSizes(), intDType); // rewrite currOp to have new operands and result type // store this as oldOpd for next loop oldOpd = rewriter .create(loc, (currOp->getName()).getIdentifier(), currOperands, intType, currOp->getAttrs()) ->getResult(0); } // stack is empty, so oldOpd is now the corrected verion of the // SrcOp's original operand // convert operand -> SrcOp to oldOpd -> newMPTQTOp -> SrcOp auto MPTQTOperands = dequantOpd.getDefiningOp()->getOperands(); auto qTorchType = cast(dequantOpd.getType()).getOptionalDtype(); auto newMPTQTType = rewriter.getType( cast(operands[i].getType()).getSizes(), qTorchType); operands[i] = rewriter.create( loc, newMPTQTType, oldOpd, MPTQTOperands[1], MPTQTOperands[2]); } if (!dequanted) { return rewriter.notifyMatchFailure(op, "No dequantizations found."); } rewriter.replaceOpWithNewOp(op, op.getType(), operands); return success(); } }; template class QuantizeBias : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(SrcOp op, PatternRewriter &rewriter) const override { llvm::SmallVector operands(op->getOperands()); if (operands.size() < 3) return failure(); Value lhsScale; if (auto qLhs = operands[0].getDefiningOp()) lhsScale = qLhs.getScale(); Value rhsScale; if (auto qRhs = operands[1].getDefiningOp()) rhsScale = qRhs.getScale(); if (!rhsScale || !lhsScale) return failure(); auto resultTy = cast(op.getType()); if (!isa(resultTy.getDtype())) return failure(); Value bias = operands[2]; auto biasTy = dyn_cast(bias.getType()); if (biasTy) { auto biasETy = biasTy.getOptionalDtype(); if (!biasETy || !isa(biasETy)) return failure(); } Value biasScale = rewriter.create( op.getLoc(), lhsScale.getType(), lhsScale, rhsScale); Value zero = rewriter.create( op.getLoc(), rewriter.getType(), rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0)); auto qi32Ty = rewriter.getType(); if (biasTy) { auto newBiasTy = rewriter.getType(biasTy.getOptionalSizes(), qi32Ty); Value dtype = getDtypeIntValueForType(rewriter, op.getLoc(), qi32Ty); bias = rewriter.create( op.getLoc(), newBiasTy, bias, biasScale, zero, dtype); bias = rewriter.create( op.getLoc(), rewriter.getType( biasTy.getOptionalSizes(), rewriter.getIntegerType(32, IntegerType::Signed)), bias); operands[2] = bias; } auto convTy = rewriter.getType( resultTy.getOptionalSizes(), rewriter.getIntegerType(32, IntegerType::Signed)); auto conv = rewriter.create(op.getLoc(), convTy, operands); auto convQTy = rewriter.getType(resultTy.getOptionalSizes(), qi32Ty); auto makeOut = rewriter.create( op.getLoc(), convQTy, conv, biasScale, zero); rewriter.replaceOpWithNewOp(op, op.getType(), makeOut); return success(); } }; template class QuantizeAccumulator : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(SrcOp op, PatternRewriter &rewriter) const override { auto lhs = op.getOperand(0); auto rhs = op.getOperand(1); auto resultTy = dyn_cast_or_null(op.getType()); if (!resultTy || !resultTy.hasDtype()) return failure(); Type resultETy = resultTy.getDtype(); if (!isa(resultETy)) return failure(); Value lhsScale; if (auto defining = lhs.template getDefiningOp()) { lhsScale = defining.getScale(); } Value rhsScale; if (auto defining = rhs.template getDefiningOp()) { rhsScale = defining.getScale(); } if (!lhsScale || !rhsScale) return failure(); // Quantize the bias input to the expected result: Value zero = rewriter.create( op.getLoc(), rewriter.getType(), rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0)); auto qi32Ty = rewriter.getType(); Value biasScale = rewriter.create( op.getLoc(), lhsScale.getType(), lhsScale, rhsScale); // Update the quantied type: llvm::SmallVector operands(op.getOperands()); auto newResultTy = rewriter.getType(resultTy.getOptionalSizes(), qi32Ty); auto conv = rewriter.create(op.getLoc(), newResultTy, operands); // Attach the quantize information to the resulting qint32: auto intReprTy = rewriter.getType( resultTy.getOptionalSizes(), rewriter.getIntegerType(32, IntegerType::Signed)); auto intRepr = rewriter.create(op.getLoc(), intReprTy, conv); auto quantTy = rewriter.getType(resultTy.getOptionalSizes(), qi32Ty); auto quant = rewriter.create( op.getLoc(), quantTy, intRepr, biasScale, zero); auto dequant = rewriter.create(op.getLoc(), resultTy, quant); rewriter.replaceOp(op, dequant); return success(); } }; // Use for ops which do not manipulate scale/zero point of an input. template class QuantizeResultLikeOperand : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(SrcOp op, PatternRewriter &rewriter) const override { llvm::SmallVector operands(op->getOperands()); Value input = operands[0]; auto inputType = dyn_cast_or_null(input.getType()); if (!inputType || !inputType.hasDtype()) return failure(); auto qDtype = inputType.getDtype(); auto resultTy = dyn_cast_or_null(op.getType()); if (!resultTy || !resultTy.hasDtype()) return failure(); Type resultETy = resultTy.getDtype(); if (!isa(resultETy)) return failure(); Value inputScale, inputZeroPoint; Type definingOpInputType; if (auto defining = input.template getDefiningOp< Aten_MakePerTensorQuantizedTensorOp>()) { inputScale = defining.getScale(); inputZeroPoint = defining.getZeroPoint(); definingOpInputType = defining.getSelf().getType(); } auto inputIntReprType = dyn_cast_or_null(definingOpInputType); if (!inputScale || !inputZeroPoint || !inputIntReprType || !inputIntReprType.hasDtype()) return failure(); auto intReprDtype = inputIntReprType.getDtype(); // set SrcOp type to use quantized dtype from input auto newResultTy = rewriter.getType(resultTy.getOptionalSizes(), qDtype); auto newResult = rewriter.create(op.getLoc(), newResultTy, operands); // int repr to get non quantized int type result auto intReprTy = rewriter.getType( resultTy.getOptionalSizes(), intReprDtype); auto intRepr = rewriter.create(op.getLoc(), intReprTy, newResult); // requantize so the scale and zero-point info can be attached auto quantTy = rewriter.getType(resultTy.getOptionalSizes(), qDtype); auto quant = rewriter.create( op.getLoc(), quantTy, intRepr, inputScale, inputZeroPoint); // dequant back to original dtype auto dequant = rewriter.create(op.getLoc(), resultTy, quant); rewriter.replaceOp(op, dequant); return success(); } }; template class RemoveUnused : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(SrcOp op, PatternRewriter &rewriter) const override { auto result = op.getResult(); if (result.use_empty()) { op.erase(); return success(); } return failure(); } }; class FuseQuantizedOpsPass : public FuseQuantizedOpsBase { public: void runOnOperation() override { MLIRContext *context = &getContext(); RewritePatternSet patterns(context); patterns.insert< RemoveUnused, RemoveUnused, RemoveUnused, RemoveUnused, RemoveUnused, RemoveUnused, RemoveUnused, RemoveUnused, RemoveUnused, RemoveUnused, RemoveUnused, QuantizeOperandsPastCommutingOps, QuantizeOperandsPastCommutingOps, QuantizeOperandsPastCommutingOps, QuantizeOperandsPastCommutingOps, QuantizeAccumulator, QuantizeAccumulator, QuantizeResultLikeOperand, QuantizeBias>( context); GreedyRewriteConfig config; if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns), config))) { return signalPassFailure(); } } }; } // namespace std::unique_ptr> mlir::torch::Torch::createFuseQuantizedOpsPass() { return std::make_unique(); }