//===----------------------------------------------------------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // Also available under a BSD-style license. See LICENSE. // //===----------------------------------------------------------------------===// #include "PassDetail.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "torch-mlir/Dialect/Torch/IR/TorchOps.h" #include "torch-mlir/Dialect/Torch/Transforms/Passes.h" #include "torch-mlir/Dialect/Torch/Utils/Utils.h" using namespace mlir; using namespace mlir::torch; using namespace mlir::torch::Torch; namespace { template class QuantizeOperands : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(SrcOp op, PatternRewriter &rewriter) const override { llvm::SmallVector operands(op->getOperands()); bool dequanted = false; auto f = [&dequanted](Value operand) { if (auto dequant = operand.getDefiningOp()) { operand = dequant.getOperand(); dequanted = true; } if (auto dequant = operand.getDefiningOp()) { operand = dequant.getOperand(); dequanted = true; } return operand; }; operands[0] = f(operands[0]); operands[1] = f(operands[1]); if (!dequanted) { return rewriter.notifyMatchFailure(op, "no dequantizations found"); } rewriter.replaceOpWithNewOp(op, op.getType(), operands); return success(); } }; template class QuantizeTransposedOperands : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(SrcOp op, PatternRewriter &rewriter) const override { llvm::SmallVector operands(op->getOperands()); unsigned numOperands = operands.size(); bool dequanted = false; for (unsigned i = 0; i < numOperands; i++) { if (auto trans = operands[i].getDefiningOp()) { auto transOperands = trans.getOperands(); Value dequantOperand; if (auto dequant = transOperands[0].getDefiningOp()) { dequantOperand = dequant.getOperand(); if (auto quant = dequantOperand .getDefiningOp()) { auto quantOperands = quant.getOperands(); auto qType = quantOperands[0] .getType() .cast() .getOptionalDtype(); auto torchQType = quant.getType().cast().getOptionalDtype(); auto transQTy = rewriter.getType(trans.getResult() .getType() .cast() .getOptionalSizes(), qType); auto newQuantTy = rewriter.getType(trans.getResult() .getType() .cast() .getOptionalSizes(), torchQType); Value newTrans = rewriter.create( op.getLoc(), transQTy, quantOperands[0], transOperands[1], transOperands[2]); Value newQuant = rewriter.create( op.getLoc(), newQuantTy, newTrans, quantOperands[1], quantOperands[2]); operands[i] = newQuant; dequanted = true; } } } } if (!dequanted) { return rewriter.notifyMatchFailure( op, "no dequantized transpose inputs found."); } rewriter.replaceOpWithNewOp(op, op.getType(), operands); return success(); } }; template class QuantizeBias : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(SrcOp op, PatternRewriter &rewriter) const override { llvm::SmallVector operands(op->getOperands()); if (operands.size() < 3) return failure(); Value lhsScale; if (auto qLhs = operands[0].getDefiningOp()) lhsScale = qLhs.getScale(); Value rhsScale; if (auto qRhs = operands[1].getDefiningOp()) rhsScale = qRhs.getScale(); if (!rhsScale || !lhsScale) return failure(); auto resultTy = cast(op.getType()); if (!isa(resultTy.getDtype())) return failure(); Value bias = operands[2]; auto biasTy = bias.getType().dyn_cast(); if (biasTy) { auto biasETy = biasTy.getOptionalDtype(); if (!biasETy || !isa(biasETy)) return failure(); } Value biasScale = rewriter.create( op.getLoc(), lhsScale.getType(), lhsScale, rhsScale); Value zero = rewriter.create( op.getLoc(), rewriter.getType(), rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0)); auto qi32Ty = rewriter.getType(); if (biasTy) { auto newBiasTy = rewriter.getType(biasTy.getOptionalSizes(), qi32Ty); Value dtype = getDtypeIntValueForType(rewriter, op.getLoc(), qi32Ty); bias = rewriter.create( op.getLoc(), newBiasTy, bias, biasScale, zero, dtype); bias = rewriter.create( op.getLoc(), rewriter.getType( biasTy.getOptionalSizes(), rewriter.getIntegerType(32, IntegerType::Signed)), bias); operands[2] = bias; } auto convTy = rewriter.getType( resultTy.getOptionalSizes(), rewriter.getIntegerType(32, IntegerType::Signed)); auto conv = rewriter.create(op.getLoc(), convTy, operands); auto convQTy = rewriter.getType(resultTy.getOptionalSizes(), qi32Ty); auto makeOut = rewriter.create( op.getLoc(), convQTy, conv, biasScale, zero); rewriter.replaceOpWithNewOp(op, op.getType(), makeOut); return success(); } }; template class QuantizeAccumulator : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(SrcOp op, PatternRewriter &rewriter) const override { auto lhs = op.getOperand(0); auto rhs = op.getOperand(1); auto resultTy = dyn_cast_or_null(op.getType()); if (!resultTy || !resultTy.hasDtype()) return failure(); Type resultETy = resultTy.getDtype(); if (!isa(resultETy)) return failure(); Value lhsScale; if (auto defining = lhs.template getDefiningOp()) { lhsScale = defining.getScale(); } Value rhsScale; if (auto defining = rhs.template getDefiningOp()) { rhsScale = defining.getScale(); } if (!lhsScale || !rhsScale) return failure(); // Quantize the bias input to the expected result: Value zero = rewriter.create( op.getLoc(), rewriter.getType(), rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0)); auto qi32Ty = rewriter.getType(); Value biasScale = rewriter.create( op.getLoc(), lhsScale.getType(), lhsScale, rhsScale); // Update the quantied type: llvm::SmallVector operands(op.getOperands()); auto newResultTy = rewriter.getType(resultTy.getOptionalSizes(), qi32Ty); auto conv = rewriter.create(op.getLoc(), newResultTy, operands); // Attach the quantize information to the resulting qint32: auto intReprTy = rewriter.getType( resultTy.getOptionalSizes(), rewriter.getIntegerType(32, IntegerType::Signed)); auto intRepr = rewriter.create(op.getLoc(), intReprTy, conv); auto quantTy = rewriter.getType(resultTy.getOptionalSizes(), qi32Ty); auto quant = rewriter.create( op.getLoc(), quantTy, intRepr, biasScale, zero); auto dequant = rewriter.create(op.getLoc(), resultTy, quant); rewriter.replaceOp(op, dequant); return success(); } }; template class RemoveUnused : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(SrcOp op, PatternRewriter &rewriter) const override { auto result = op.getResult(); if (result.use_empty()) { op.erase(); return success(); } return failure(); } }; class FuseQuantizedOpsPass : public FuseQuantizedOpsBase { public: void runOnOperation() override { MLIRContext *context = &getContext(); RewritePatternSet patterns(context); patterns.insert< RemoveUnused, RemoveUnused, RemoveUnused, RemoveUnused, RemoveUnused, QuantizeOperands, QuantizeOperands, QuantizeTransposedOperands, QuantizeAccumulator, QuantizeBias>( context); GreedyRewriteConfig config; if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns), config))) { return signalPassFailure(); } } }; } // namespace std::unique_ptr> mlir::torch::Torch::createFuseQuantizedOpsPass() { return std::make_unique(); }