mirror of https://github.com/llvm/torch-mlir

[torch][quant] Quantized `torch.mm` for linalg with end-to-end test (#2750)

This includes custom op matching for decomposed operations and fusing dequantization into dense operations. As validation, we compare against the dequant+mm torch implementation.

pull/2803/head
parent 60bf6c25af
commit f6f890520b

@ -1 +1 @@
Subproject commit 0cb024b357aff294b1ba0f9d3de8f48ab684962b
Subproject commit eae82ac259ee5a58bc4070a414bc53239e18bad0
@ -106,6 +106,10 @@ createDecomposeComplexOpsPass(ArrayRef<std::string> legalOps);

std::unique_ptr<OperationPass<func::FuncOp>> createRecomposeComplexOpsPass();

std::unique_ptr<OperationPass<func::FuncOp>> createFuseQuantizedOpsPass();

std::unique_ptr<OperationPass<func::FuncOp>>
createMatchQuantizedCustomOpsPass();

std::unique_ptr<OperationPass<ModuleOp>>
createReifyShapeCalculationsPass(StringRef extraLibrary);
@ -258,6 +258,34 @@ def RecomposeComplexOps : Pass<"torch-recompose-complex-ops", "func::FuncOp"> {
  }];
}

def FuseQuantizedOps : Pass<"torch-fuse-quantized-ops", "func::FuncOp"> {
  let summary = "QDQ: Fuse recognized QDQ op sequences.";
  let constructor = "mlir::torch::Torch::createFuseQuantizedOpsPass()";
  let description = [{
    Torch models often represent quantized operations as the sequence:
      Dequantize
      DenseOp
      Quantize
    This allows the existing dense operations to be used without specifically
    representing quantized types. It is more computationally efficient to
    perform the dense operation in the quantized domain, so we fuse the
    quantization / dequantization behavior together and represent them as
    purely quantized operations.
  }];
}

def MatchQuantizedCustomOps : Pass<"torch-match-quantized-custom-ops", "func::FuncOp"> {
  let summary = "Match quantized operations that occur in a different namespace.";
  let constructor = "mlir::torch::Torch::createMatchQuantizedCustomOpsPass()";
  let description = [{
    Torch quantization utilities generate custom op versions of known aten
    quantization operations. We can match these specially named operations and
    rewrite them to the corresponding aten quantized operations.

    We handle this post-import to maintain a simplified import process.
  }];
}

def ReifyShapeCalculations : Pass<"torch-reify-shape-calculations", "ModuleOp"> {
  let summary = "Reify shape calculations.";
  let constructor = [{
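For intuition, the fusion relies on the usual per-tensor affine-quantization identity: dequantizing two quantized values and multiplying them gives the same result as multiplying the zero-point-shifted integers and applying the product of the two scales. The following is a standalone illustrative sketch of that identity, with made-up scales and zero points (not taken from this patch):

// Standalone sketch: per-tensor affine quantization,
// q = round(x / scale) + zero_point, so x ~= (q - zero_point) * scale.
#include <cassert>
#include <cmath>
#include <cstdint>
#include <iostream>

int main() {
  // Hypothetical per-tensor parameters for two operands.
  double lhsScale = 0.5, rhsScale = 0.25;
  int32_t lhsZp = 1, rhsZp = 0;
  int8_t qLhs = 7, qRhs = -3; // raw quantized integers

  // Path 1: dequantize each operand, then multiply in floating point.
  double viaFloat = (qLhs - lhsZp) * lhsScale * ((qRhs - rhsZp) * rhsScale);

  // Path 2: multiply in the integer domain, then apply the combined scale.
  int32_t acc = (qLhs - lhsZp) * (qRhs - rhsZp);
  double viaInt = acc * (lhsScale * rhsScale);

  std::cout << viaFloat << " == " << viaInt << "\n";
  assert(std::abs(viaFloat - viaInt) < 1e-12);
  return 0;
}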
@ -29,6 +29,13 @@ using namespace mlir::torch;
using namespace mlir::torch::Torch;

namespace {

static void getZeroPoint(Value value, Value &zeropoint) {
  if (auto make = value.getDefiningOp<Aten_MakePerTensorQuantizedTensorOp>()) {
    zeropoint = make.getZeroPoint();
  }
}

class ConvertAtenMmOp : public OpConversionPattern<AtenMmOp> {
public:
  using OpConversionPattern::OpConversionPattern;
@ -64,11 +71,27 @@ public:
        op.getSelf().getType().cast<ValueTensorType>();
    ValueTensorType rhsTorchType =
        op.getMat2().getType().cast<ValueTensorType>();

    Value lhsZeroPoint, rhsZeroPoint;
    getZeroPoint(op.getSelf(), lhsZeroPoint);
    getZeroPoint(op.getMat2(), rhsZeroPoint);

    if (static_cast<bool>(lhsZeroPoint) != static_cast<bool>(rhsZeroPoint)) {
      return rewriter.notifyMatchFailure(
          op, "unsupported: aten.mm with mixed quantization");
    }

    if (lhsTorchType.getDtype() != rhsTorchType.getDtype()) {
      return rewriter.notifyMatchFailure(
          op, "unsupported: aten.mm with different input element types");
    }

    bool isUnsigned = torch_to_linalg::isUnsignedTorchType(lhsTorchType);
    if (lhsZeroPoint && isUnsigned) {
      return rewriter.notifyMatchFailure(
          op, "unsupported: unsigned quantized matmul not supported");
    }

    Value lhsDim0 = rewriter.create<tensor::DimOp>(loc, lhs, 0);
    Value rhsDim1 = rewriter.create<tensor::DimOp>(loc, rhs, 1);
@ -89,8 +112,26 @@ public:
         rewriter, loc, ValueRange{lhsDim0, rhsDim1}, elementType);
 
     Value matmul;
-    auto intType = dyn_cast<mlir::IntegerType>(lhsTorchType.getDtype());
-    if (intType && intType.isUnsigned()) {
+    if (lhsZeroPoint && !isUnsigned) {
+      lhsZeroPoint = typeConverter->materializeTargetConversion(
+          rewriter, loc,
+          getTypeConverter()->convertType(lhsZeroPoint.getType()),
+          lhsZeroPoint);
+      rhsZeroPoint = typeConverter->materializeTargetConversion(
+          rewriter, loc,
+          getTypeConverter()->convertType(rhsZeroPoint.getType()),
+          rhsZeroPoint);
+      lhsZeroPoint = rewriter.create<arith::TruncIOp>(
+          loc, rewriter.getI32Type(), lhsZeroPoint);
+      rhsZeroPoint = rewriter.create<arith::TruncIOp>(
+          loc, rewriter.getI32Type(), rhsZeroPoint);
+      matmul =
+          rewriter
+              .create<linalg::QuantizedMatmulOp>(
+                  loc, zeroFill.getType(),
+                  ValueRange{lhs, rhs, lhsZeroPoint, rhsZeroPoint}, zeroFill)
+              .getResult(0);
+    } else if (isUnsigned) {
       matmul = rewriter
                    .create<linalg::MatmulUnsignedOp>(
                        loc, zeroFill.getType(), ValueRange{lhs, rhs}, zeroFill)
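The linalg.quantized_matmul produced above accumulates zero-point-corrected products in i32. A standalone illustrative reference loop for what that accumulation computes (not code from the patch; shapes and values are made up):

// Standalone sketch: reference semantics of a zero-point-corrected matmul,
// acc[i][j] += (lhs[i][k] - lhsZp) * (rhs[k][j] - rhsZp), accumulated in i32.
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  const int M = 2, K = 3, N = 2;
  std::vector<int8_t> lhs = {1, 2, 3, 4, 5, 6};    // M x K, row-major
  std::vector<int8_t> rhs = {7, 8, 9, 10, 11, 12}; // K x N, row-major
  int32_t lhsZp = 1, rhsZp = 0;                    // hypothetical zero points

  std::vector<int32_t> acc(M * N, 0); // zero-filled init tensor
  for (int i = 0; i < M; ++i)
    for (int j = 0; j < N; ++j)
      for (int k = 0; k < K; ++k)
        acc[i * N + j] +=
            (int32_t(lhs[i * K + k]) - lhsZp) * (int32_t(rhs[k * N + j]) - rhsZp);

  for (int i = 0; i < M; ++i) {
    for (int j = 0; j < N; ++j)
      std::cout << acc[i * N + j] << " ";
    std::cout << "\n";
  }
  return 0;
}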
@ -3,10 +3,12 @@ add_mlir_library(TorchMLIRTorchPasses
  DecomposeComplexOps.cpp
  DropAbstractInterpCalculations.cpp
  EraseModuleInitializer.cpp
  FuseQuantizedOps.cpp
  Passes.cpp
  GlobalizeObjectGraph.cpp
  InlineGlobalSlots.cpp
  LowerToBackendContract.cpp
  MatchQuantizedOps.cpp
  MaximizeValueSemantics.cpp
  PrepareForGlobalizeObjectGraph.cpp
  RecomposeComplexOps.cpp
@ -0,0 +1,214 @@
//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//

#include "PassDetail.h"

#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
#include "torch-mlir/Dialect/Torch/Transforms/Passes.h"
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"

using namespace mlir;
using namespace mlir::torch;
using namespace mlir::torch::Torch;

namespace {

template <typename SrcOp>
class QuantizeOperands : public OpRewritePattern<SrcOp> {
public:
  using OpRewritePattern<SrcOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(SrcOp op,
                                PatternRewriter &rewriter) const override {
    llvm::SmallVector<Value> operands(op->getOperands());

    bool dequanted = false;
    for (auto &operand : operands) {
      if (auto dequant = operand.getDefiningOp<AtenDequantizeTensorOp>()) {
        operand = dequant.getOperand();
        dequanted = true;
      }
      if (auto dequant = operand.getDefiningOp<AtenDequantizeSelfOp>()) {
        operand = dequant.getOperand();
        dequanted = true;
      }
    }

    if (!dequanted) {
      return rewriter.notifyMatchFailure(op, "no dequantizations found");
    }

    rewriter.replaceOpWithNewOp<SrcOp>(op, op.getType(), operands);
    return success();
  }
};

template <typename SrcOp> class QuantizeBias : public OpRewritePattern<SrcOp> {
public:
  using OpRewritePattern<SrcOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(SrcOp op,
                                PatternRewriter &rewriter) const override {
    llvm::SmallVector<Value> operands(op->getOperands());
    if (operands.size() < 3)
      return failure();

    Value bias = operands[2];
    if (bias.getDefiningOp<AtenDequantizeTensorOp>())
      return failure();

    Value lhsScale;
    if (auto qLhs =
            operands[0].getDefiningOp<Aten_MakePerTensorQuantizedTensorOp>())
      lhsScale = qLhs.getScale();

    Value rhsScale;
    if (auto qRhs =
            operands[1].getDefiningOp<Aten_MakePerTensorQuantizedTensorOp>())
      rhsScale = qRhs.getScale();

    if (!rhsScale || !lhsScale)
      return failure();

    auto biasTy = bias.getType().cast<ValueTensorType>();
    auto biasETy = biasTy.getOptionalDtype();
    if (!biasETy || !isa<mlir::FloatType>(biasETy))
      return failure();

    Value biasScale = rewriter.create<AtenMulFloatOp>(
        op.getLoc(), lhsScale.getType(), lhsScale, rhsScale);

    Value zero = rewriter.create<Torch::ConstantIntOp>(
        op.getLoc(), rewriter.getType<Torch::IntType>(),
        rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0));

    auto qi32Ty = rewriter.getType<QInt32Type>();
    auto newBiasTy =
        rewriter.getType<ValueTensorType>(biasTy.getOptionalSizes(), qi32Ty);
    Value dtype = getDtypeIntValueForType(rewriter, op.getLoc(), qi32Ty);
    bias = rewriter.create<AtenQuantizePerTensorOp>(
        op.getLoc(), newBiasTy, bias, biasScale, zero, dtype);

    operands[2] = bias;
    rewriter.replaceOpWithNewOp<SrcOp>(op, op.getType(), operands);
    return success();
  }
};

template <typename SrcOp>
class QuantizeAccumulator : public OpRewritePattern<SrcOp> {
public:
  using OpRewritePattern<SrcOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(SrcOp op,
                                PatternRewriter &rewriter) const override {
    auto lhs = op.getOperand(0);
    auto rhs = op.getOperand(1);

    auto resultTy = dyn_cast_or_null<ValueTensorType>(op.getType());
    if (!resultTy || !resultTy.hasDtype())
      return failure();

    Type resultETy = resultTy.getDtype();
    if (!resultETy.isa<mlir::FloatType>())
      return failure();

    Value lhsScale;
    if (auto defining =
            lhs.template getDefiningOp<Aten_MakePerTensorQuantizedTensorOp>()) {
      lhsScale = defining.getScale();
    }

    Value rhsScale;
    if (auto defining =
            rhs.template getDefiningOp<Aten_MakePerTensorQuantizedTensorOp>()) {
      rhsScale = defining.getScale();
    }

    if (!lhsScale || !rhsScale)
      return failure();

    // Quantize the bias input to the expected result:
    Value zero = rewriter.create<Torch::ConstantIntOp>(
        op.getLoc(), rewriter.getType<Torch::IntType>(),
        rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0));

    auto qi32Ty = rewriter.getType<QInt32Type>();
    Value biasScale = rewriter.create<AtenMulFloatOp>(
        op.getLoc(), lhsScale.getType(), lhsScale, rhsScale);

    // Update the quantized type:
    llvm::SmallVector<Value> operands(op.getOperands());

    auto newResultTy =
        rewriter.getType<ValueTensorType>(resultTy.getOptionalSizes(), qi32Ty);
    auto conv = rewriter.create<SrcOp>(op.getLoc(), newResultTy, operands);

    // Attach the quantize information to the resulting qint32:
    auto intReprTy = rewriter.getType<ValueTensorType>(
        resultTy.getOptionalSizes(),
        rewriter.getIntegerType(32, IntegerType::Signed));
    auto intRepr = rewriter.create<AtenIntReprOp>(op.getLoc(), intReprTy, conv);

    auto quantTy =
        rewriter.getType<ValueTensorType>(resultTy.getOptionalSizes(), qi32Ty);
    auto quant = rewriter.create<Aten_MakePerTensorQuantizedTensorOp>(
        op.getLoc(), quantTy, intRepr, biasScale, zero);
    auto dequant =
        rewriter.create<AtenDequantizeTensorOp>(op.getLoc(), resultTy, quant);
    rewriter.replaceOp(op, dequant);

    return success();
  }
};

template <typename SrcOp> class RemoveUnused : public OpRewritePattern<SrcOp> {
public:
  using OpRewritePattern<SrcOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(SrcOp op,
                                PatternRewriter &rewriter) const override {
    auto result = op.getResult();
    if (result.use_empty()) {
      op.erase();
      return success();
    }
    return failure();
  }
};

class FuseQuantizedOpsPass : public FuseQuantizedOpsBase<FuseQuantizedOpsPass> {
public:
  void runOnOperation() override {
    MLIRContext *context = &getContext();
    RewritePatternSet patterns(context);
    patterns
        .insert<RemoveUnused<AtenDequantizeSelfOp>,
                RemoveUnused<AtenDequantizeTensorOp>,
                RemoveUnused<AtenQuantizePerTensorOp>,
                QuantizeOperands<AtenConvolutionOp>, QuantizeOperands<AtenMmOp>,
                QuantizeAccumulator<AtenConvolutionOp>,
                QuantizeAccumulator<AtenMmOp>, QuantizeBias<AtenConvolutionOp>>(
            context);

    GreedyRewriteConfig config;
    if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns),
                                            config))) {
      return signalPassFailure();
    }
  }
};

} // namespace

std::unique_ptr<OperationPass<func::FuncOp>>
mlir::torch::Torch::createFuseQuantizedOpsPass() {
  return std::make_unique<FuseQuantizedOpsPass>();
}
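The accumulator and bias rewrites above hinge on one piece of arithmetic: the product of two per-tensor quantized tensors is naturally quantized with scale lhsScale * rhsScale and zero point 0, so a float bias can be requantized with that same scale before being added in the integer domain. A standalone numeric sketch of that bookkeeping (illustrative values only, not from the patch):

// Standalone sketch: why the i32 accumulator (and a bias added to it) carries
// the combined scale lhsScale * rhsScale with zero point 0.
#include <cmath>
#include <cstdint>
#include <iostream>

int main() {
  double lhsScale = 0.5, rhsScale = 0.25;
  int32_t lhsZp = 1, rhsZp = 0;
  int8_t qLhs = 9, qRhs = 5;
  double bias = 1.75; // float bias operand

  // Integer-domain product; its real value is acc * (lhsScale * rhsScale).
  int32_t acc = (qLhs - lhsZp) * (qRhs - rhsZp);
  double accScale = lhsScale * rhsScale;

  // Requantize the bias into the accumulator's scale (zero point 0, qint32).
  int32_t qBias = static_cast<int32_t>(std::llround(bias / accScale));

  // Add in the integer domain, then dequantize once at the end.
  double fused = (acc + qBias) * accScale;

  // Reference: do everything in floating point.
  double reference = (qLhs - lhsZp) * lhsScale * ((qRhs - rhsZp) * rhsScale) + bias;

  std::cout << "fused=" << fused << " reference=" << reference << "\n";
  return 0;
}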
@ -0,0 +1,109 @@
//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//

#include "PassDetail.h"

#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
#include "torch-mlir/Dialect/Torch/Transforms/Passes.h"
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"

using namespace mlir;
using namespace mlir::torch;
using namespace mlir::torch::Torch;

namespace {

Type getQuantizedType(MLIRContext *context, Type t) {
  if (t.isSignlessInteger(8))
    return Torch::QUInt8Type::get(context);
  if (t.isInteger(8) || t.isSignedInteger(8))
    return Torch::QInt8Type::get(context);
  if (t.isInteger(32))
    return Torch::QInt32Type::get(context);
  return {};
}

class MatchQuantizeOperator : public OpRewritePattern<OperatorOp> {
public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(OperatorOp op,
                                PatternRewriter &rewriter) const override {
    if (op.getName() == "torch.quantized_decomposed.quantize_per_tensor") {
      auto resultTy = cast<ValueTensorType>(op.getType(0));
      auto qeTy = getQuantizedType(rewriter.getContext(), resultTy.getDtype());
      if (!qeTy)
        qeTy = resultTy.getDtype();

      auto qTy =
          rewriter.getType<ValueTensorType>(resultTy.getOptionalSizes(), qeTy);
      Value quant = rewriter.create<AtenQuantizePerTensorOp>(
          op.getLoc(), qTy,
          /*self=*/op.getOperand(0), /*scale=*/op.getOperand(1),
          /*zero_point=*/op.getOperand(2), /*dtype=*/op.getOperand(5));

      if (qTy != resultTy) {
        quant = rewriter.create<AtenIntReprOp>(op.getLoc(), resultTy, quant);
      }

      rewriter.replaceOpWithNewOp<AtenClampOp>(
          op, resultTy, quant, op.getOperand(3), op.getOperand(4));
      return success();
    }

    if (op.getName() == "torch.quantized_decomposed.dequantize_per_tensor") {
      auto clamp = rewriter.create<AtenClampOp>(
          op.getLoc(), op.getOperand(0).getType(), op.getOperand(0),
          op.getOperand(3), op.getOperand(4));

      auto clampTy = clamp.getType().cast<Torch::ValueTensorType>();
      if (!clampTy.hasDtype())
        return rewriter.notifyMatchFailure(op,
                                           "dequantization has unknown dtype");

      Type dtype = clampTy.getDtype();
      Type qetype = getQuantizedType(op.getContext(), dtype);
      if (!qetype)
        return rewriter.notifyMatchFailure(op,
                                           "dequantization has unknown qtype");

      Type qTy = Torch::ValueTensorType::get(
          op.getContext(), clampTy.getOptionalSizes(), qetype);
      auto quant = rewriter.create<Aten_MakePerTensorQuantizedTensorOp>(
          op.getLoc(), qTy, clamp, op.getOperand(1), op.getOperand(2));
      rewriter.replaceOpWithNewOp<AtenDequantizeTensorOp>(
          op, op.getResultTypes(), quant);
      return success();
    }

    return failure();
  }
};

class MatchQuantizedCustomOpsPass
    : public MatchQuantizedCustomOpsBase<MatchQuantizedCustomOpsPass> {
public:
  void runOnOperation() override {
    MLIRContext *context = &getContext();
    RewritePatternSet patterns(context);
    patterns.insert<MatchQuantizeOperator>(context);

    GreedyRewriteConfig config;
    if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns),
                                            config)))
      return signalPassFailure();
  }
};

} // namespace

std::unique_ptr<OperationPass<func::FuncOp>>
mlir::torch::Torch::createMatchQuantizedCustomOpsPass() {
  return std::make_unique<MatchQuantizedCustomOpsPass>();
}
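The two custom ops matched above follow the standard per-tensor affine scheme that the replacement aten ops also implement: quantize as clamp(round(x / scale) + zero_point, min, max) and dequantize as (q - zero_point) * scale. A standalone reference sketch of those formulas (illustrative values mirroring the lit test below, not code from the patch):

// Standalone sketch: reference semantics of per-tensor quantize/dequantize,
// mirroring torch.quantized_decomposed.{quantize,dequantize}_per_tensor.
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <iostream>

int8_t quantizePerTensor(float x, double scale, int32_t zeroPoint,
                         int32_t qmin, int32_t qmax) {
  int32_t q = static_cast<int32_t>(std::lround(x / scale)) + zeroPoint;
  return static_cast<int8_t>(std::clamp(q, qmin, qmax));
}

float dequantizePerTensor(int8_t q, double scale, int32_t zeroPoint) {
  return static_cast<float>((q - zeroPoint) * scale);
}

int main() {
  double scale = 0.5;
  int32_t zeroPoint = 17, qmin = -128, qmax = 127;

  float x = 3.2f;
  int8_t q = quantizePerTensor(x, scale, zeroPoint, qmin, qmax); // 6 + 17 = 23
  float xhat = dequantizePerTensor(q, scale, zeroPoint);         // 3.0

  std::cout << "q=" << int(q) << " xhat=" << xhat << "\n";
  return 0;
}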
@ -15,12 +15,13 @@
#include "mlir/Dialect/Tosa/Transforms/Passes.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/Passes.h"
#include "torch-mlir/Conversion/TorchConversionToMLProgram/TorchConversionToMLProgram.h"
#include "torch-mlir/Conversion/TorchToArith/TorchToArith.h"
#include "torch-mlir/Conversion/TorchToLinalg/TorchToLinalg.h"
#include "torch-mlir/Conversion/TorchToSCF/TorchToSCF.h"
#include "torch-mlir/Conversion/TorchToTMTensor/TorchToTMTensor.h"
#include "torch-mlir/Conversion/TorchToTosa/TorchToTosa.h"
#include "torch-mlir/Dialect/Torch/Transforms/Passes.h"
#ifdef TORCH_MLIR_ENABLE_STABLEHLO
#include "torch-mlir/Conversion/TorchToStablehlo/TorchToStablehlo.h"
#endif

@ -64,6 +65,9 @@ void mlir::torch::registerTorchConversionPasses() {

void TorchConversion::createTorchBackendToLinalgOnTensorsBackendPipeline(
    OpPassManager &pm) {
  // We want to fuse quantized operations together before lowering to linalg.
  pm.addNestedPass<func::FuncOp>(Torch::createFuseQuantizedOpsPass());

  // Lower to linalg + guards which is the input to codegen backends.
  // We do this first as it tends to involve pattern-matching against constants
  // (e.g. dimensions which must be constant in a ranked programming model).
@ -39,6 +39,38 @@ std::vector<torch::lazy::Shape> compute_shape_div(const at::Tensor& self,
  return {Shape(self.scalar_type(), self.sizes().vec())};
}

std::vector<torch::lazy::Shape> compute_shape__make_per_tensor_quantized_tensor(
    const at::Tensor &self, double scale, int64_t zero_point) {
  if (self.scalar_type() == at::kChar)
    return {Shape(at::kQInt8, self.sizes().vec())};
  if (self.scalar_type() == at::kByte)
    return {Shape(at::kQUInt8, self.sizes().vec())};
  if (self.scalar_type() == at::kInt)
    return {Shape(at::kQInt32, self.sizes().vec())};
  assert(false);
}

std::vector<torch::lazy::Shape> compute_shape_int_repr(const at::Tensor &self) {
  if (self.scalar_type() == at::kQInt8)
    return {Shape(at::kChar, self.sizes().vec())};
  if (self.scalar_type() == at::kQUInt8)
    return {Shape(at::kByte, self.sizes().vec())};
  if (self.scalar_type() == at::kQInt32)
    return {Shape(at::kInt, self.sizes().vec())};
  assert(false);
}

std::vector<torch::lazy::Shape>
compute_shape_dequantize(const at::Tensor &self) {
  return {Shape(at::kFloat, self.sizes().vec())};
}

std::vector<torch::lazy::Shape>
compute_shape_quantize_per_tensor(const at::Tensor &self, double scale,
                                  int64_t zero_point, at::ScalarType dtype) {
  return {Shape(dtype, self.sizes().vec())};
}

std::vector<torch::lazy::Shape> compute_shape_isinf(const at::Tensor& self) {
  return {Shape(at::kBool, self.sizes().vec())};
}

@ -102,6 +134,12 @@ std::vector<torch::lazy::Shape> compute_shape_var(
  return {Shape(self.scalar_type(), {})};
}

std::vector<torch::lazy::Shape> compute_shape_nan_to_num(
    const at::Tensor & self, c10::optional<double> nan,
    c10::optional<double> posinf, c10::optional<double> neginf) {
  return {Shape(self.scalar_type(), self.sizes().vec())};
}

std::vector<torch::lazy::Shape> compute_shape_hardtanh(
    const at::Tensor& self, const at::Scalar& min_val,
    const at::Scalar& max_val) {
@ -315,6 +315,7 @@ TORCHDYNAMO_XFAIL_SET = {
    # Dynamo does not support tracing quantized tensors
    "ElementwiseDequantizePerTensorModule_basic",
    "ElementwiseQuantizePerTensorModule_basic",
    "AtenMmQuint8_basic",

    # Dynamo not supporting conv_tbc
    "ConvTbcModule_basic",

@ -1539,7 +1540,4 @@ LTC_XFAIL_SET = {
    "ElementwiseBitwiseAndScalarInt64Module_basic",
    "ElementwiseBitwiseAndScalarInt32Module_basic",
    "ElementwiseBitwiseAndScalarInt8Module_basic",
    "ElementwiseNanToNumModule_Basic",
    "ElementwiseQuantizePerTensorModule_basic",
    "ElementwiseDequantizePerTensorModule_basic"
}
@ -262,3 +262,30 @@ class AtenMmIntTypes(torch.nn.Module):
@register_test_case(module_factory=lambda: AtenMmIntTypes())
def AtenMmIntTypes_basic(module, tu: TestUtils):
    module.forward(tu.randint(16, 4, high=100), tu.randint(4, 16, high=100))


# ==============================================================================


class AtenMmQuint8(torch.nn.Module):

    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([3, 4], torch.int8, True),
        ([4, 3], torch.int8, True),
    ])
    def forward(self, x, y):
        qx = torch._make_per_tensor_quantized_tensor(x, 0.1, 8)
        qx = torch.dequantize(qx)
        qy = torch._make_per_tensor_quantized_tensor(y, 0.1, 8)
        qy = torch.dequantize(qy)
        qz = torch.mm(qx, qy)
        return qz


@register_test_case(module_factory=lambda: AtenMmQuint8())
def AtenMmQuint8_basic(module, tu: TestUtils):
    module.forward(tu.randint(3, 4, low=-128, high=127).to(torch.int8),
                   tu.randint(4, 3, low=-128, high=127).to(torch.int8))
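The end-to-end test above exercises the dequant+mm reference path mentioned in the commit message: both int8 inputs are interpreted with scale 0.1 and zero point 8, dequantized, and multiplied. A standalone sketch of that reference computation reduced to a single output element (plain C++, not part of the test suite; the input values are made up):

// Standalone sketch: the dequantize-then-mm reference that AtenMmQuint8
// compares against, reduced to a single dot product (K = 4).
#include <cstdint>
#include <iostream>

int main() {
  const double scale = 0.1;    // as in the test
  const int32_t zeroPoint = 8; // as in the test

  int8_t xRow[4] = {3, -5, 12, 7}; // one row of x (made-up values)
  int8_t yCol[4] = {2, 9, -1, 4};  // one column of y (made-up values)

  double acc = 0.0;
  for (int k = 0; k < 4; ++k) {
    double xf = (xRow[k] - zeroPoint) * scale; // dequantize x element
    double yf = (yCol[k] - zeroPoint) * scale; // dequantize y element
    acc += xf * yf;
  }
  std::cout << "reference output element: " << acc << "\n";
  return 0;
}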
@ -0,0 +1,62 @@
// RUN: torch-mlir-opt %s --split-input-file --torch-fuse-quantized-ops | FileCheck %s

// CHECK-LABEL: @mm
func.func @mm(%arg0: !torch.vtensor<[4, 4],si8>, %arg1: !torch.vtensor<[4, 4],si8>) -> !torch.vtensor<[4, 4],f32> {
  %scale = torch.constant.float 0.5
  %false = torch.constant.bool false
  %zero = torch.constant.int 0
  %one = torch.constant.int 1
  %zp = torch.constant.int -128
  %6 = torch.aten._make_per_tensor_quantized_tensor %arg0, %scale, %one : !torch.vtensor<[4, 4],si8>, !torch.float, !torch.int -> !torch.vtensor<[4, 4],!torch.qint8>
  %7 = torch.aten.dequantize.tensor %6 : !torch.vtensor<[4, 4],!torch.qint8> -> !torch.vtensor<[4, 4],f32>
  %12 = torch.aten._make_per_tensor_quantized_tensor %arg1, %scale, %zero : !torch.vtensor<[4, 4],si8>, !torch.float, !torch.int -> !torch.vtensor<[4, 4],!torch.qint8>
  %13 = torch.aten.dequantize.tensor %12 : !torch.vtensor<[4, 4],!torch.qint8> -> !torch.vtensor<[4, 4],f32>
  %16 = torch.aten.mm %7, %13 : !torch.vtensor<[4, 4],f32>, !torch.vtensor<[4, 4],f32> -> !torch.vtensor<[4, 4],f32>

  // CHECK-DAG: %[[ZERO:.+]] = torch.constant.int 0
  // CHECK-DAG: %[[QUARTER:.+]] = torch.constant.float 2.500000e-01
  // CHECK-DAG: %[[HALF:.+]] = torch.constant.float 5.000000e-01
  // CHECK-DAG: %[[ONE:.+]] = torch.constant.int 1
  // CHECK-DAG: %[[QLHS:.+]] = torch.aten._make_per_tensor_quantized_tensor %arg0, %[[HALF:.+]], %[[ONE]] : !torch.vtensor<[4,4],si8>, !torch.float, !torch.int -> !torch.vtensor<[4,4],!torch.qint8>
  // CHECK-DAG: %[[QRHS:.+]] = torch.aten._make_per_tensor_quantized_tensor %arg1, %[[HALF:.+]], %[[ZERO]] : !torch.vtensor<[4,4],si8>, !torch.float, !torch.int -> !torch.vtensor<[4,4],!torch.qint8>
  // CHECK-DAG: %[[MM:.+]] = torch.aten.mm %[[QLHS]], %[[QRHS]] : !torch.vtensor<[4,4],!torch.qint8>, !torch.vtensor<[4,4],!torch.qint8> -> !torch.vtensor<[4,4],!torch.qint32>
  // CHECK-DAG: %[[INT:.+]] = torch.aten.int_repr %[[MM]] : !torch.vtensor<[4,4],!torch.qint32> -> !torch.vtensor<[4,4],si32>
  // CHECK-DAG: %[[QOUT:.+]] = torch.aten._make_per_tensor_quantized_tensor %[[INT]], %[[QUARTER]], %[[ZERO]] : !torch.vtensor<[4,4],si32>, !torch.float, !torch.int -> !torch.vtensor<[4,4],!torch.qint32>
  // CHECK: %[[OUT:.+]] = torch.aten.dequantize.tensor %[[QOUT]] : !torch.vtensor<[4,4],!torch.qint32> -> !torch.vtensor<[4,4],f32>
  return %16 : !torch.vtensor<[4, 4],f32>
}

// -----

// CHECK-LABEL: @convolution
func.func @convolution(%arg0: !torch.vtensor<[1,3,8,8],si8>, %arg1: !torch.vtensor<[3,3,2,2],si8>, %arg2 : !torch.vtensor<[3], f32>) -> !torch.vtensor<[1,3,7,7],f32> {
  %scale = torch.constant.float 0.5
  %false = torch.constant.bool false
  %zero = torch.constant.int 0
  %one = torch.constant.int 1
  %zp = torch.constant.int -128
  %6 = torch.aten._make_per_tensor_quantized_tensor %arg0, %scale, %one : !torch.vtensor<[1,3,8,8],si8>, !torch.float, !torch.int -> !torch.vtensor<[1,3,8,8],!torch.qint8>
  %7 = torch.aten.dequantize.tensor %6 : !torch.vtensor<[1,3,8,8],!torch.qint8> -> !torch.vtensor<[1,3,8,8],f32>
  %12 = torch.aten._make_per_tensor_quantized_tensor %arg1, %scale, %zero : !torch.vtensor<[3,3,2,2],si8>, !torch.float, !torch.int -> !torch.vtensor<[3,3,2,2],!torch.qint8>
  %13 = torch.aten.dequantize.tensor %12 : !torch.vtensor<[3,3,2,2],!torch.qint8> -> !torch.vtensor<[3,3,2,2],f32>
  %14 = torch.prim.ListConstruct %one, %one : (!torch.int, !torch.int) -> !torch.list<int>
  %15 = torch.prim.ListConstruct %zero, %zero : (!torch.int, !torch.int) -> !torch.list<int>
  %16 = torch.aten.convolution %7, %13, %arg2, %14, %15, %14, %false, %15, %one : !torch.vtensor<[1,3,8,8],f32>, !torch.vtensor<[3,3,2,2],f32>, !torch.vtensor<[3],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int -> !torch.vtensor<[1,3,7,7],f32>

  // CHECK-DAG: %[[ZERO:.+]] = torch.constant.int 0
  // CHECK-DAG: %[[SCALEO:.+]] = torch.constant.float 2.500000e-01
  // CHECK-DAG: %[[DTYPE:.+]] = torch.constant.int 14
  // CHECK-DAG: %[[HALF:.+]] = torch.constant.float 5.000000e-01
  // CHECK-DAG: %[[FALSE:.+]] = torch.constant.bool false
  // CHECK-DAG: %[[ONE:.+]] = torch.constant.int 1
  // CHECK-DAG: %[[QLHS:.+]] = torch.aten._make_per_tensor_quantized_tensor %arg0, %[[HALF]], %[[ONE]] : !torch.vtensor<[1,3,8,8],si8>, !torch.float, !torch.int -> !torch.vtensor<[1,3,8,8],!torch.qint8>
  // CHECK-DAG: %[[QRHS:.+]] = torch.aten._make_per_tensor_quantized_tensor %arg1, %[[HALF]], %[[ZERO]] : !torch.vtensor<[3,3,2,2],si8>, !torch.float, !torch.int -> !torch.vtensor<[3,3,2,2],!torch.qint8>
  // CHECK-DAG: %[[ONES:.+]] = torch.prim.ListConstruct %[[ONE]], %[[ONE]] : (!torch.int, !torch.int) -> !torch.list<int>
  // CHECK-DAG: %[[ZEROS:.+]] = torch.prim.ListConstruct %[[ZERO]], %[[ZERO]] : (!torch.int, !torch.int) -> !torch.list<int>
  // CHECK-DAG: %[[QBIAS:.+]] = torch.aten.quantize_per_tensor %arg2, %[[SCALEO]], %[[ZERO]], %[[DTYPE]] : !torch.vtensor<[3],f32>, !torch.float, !torch.int, !torch.int -> !torch.vtensor<[3],!torch.qint32>
  // CHECK-DAG: %[[CONV:.+]] = torch.aten.convolution %[[QLHS]], %[[QRHS]], %[[QBIAS]], %[[ONES]], %[[ZEROS]], %[[ONES]], %[[FALSE]], %[[ZEROS]], %[[ONE]] : !torch.vtensor<[1,3,8,8],!torch.qint8>, !torch.vtensor<[3,3,2,2],!torch.qint8>, !torch.vtensor<[3],!torch.qint32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int -> !torch.vtensor<[1,3,7,7],!torch.qint32>
  // CHECK-DAG: %[[INT:.+]] = torch.aten.int_repr %[[CONV]] : !torch.vtensor<[1,3,7,7],!torch.qint32> -> !torch.vtensor<[1,3,7,7],si32>
  // CHECK-DAG: %[[QOUT:.+]] = torch.aten._make_per_tensor_quantized_tensor %[[INT]], %[[SCALEO]], %[[ZERO]] : !torch.vtensor<[1,3,7,7],si32>, !torch.float, !torch.int -> !torch.vtensor<[1,3,7,7],!torch.qint32>
  // CHECK: %[[FOUT:.+]] = torch.aten.dequantize.tensor %[[QOUT]] : !torch.vtensor<[1,3,7,7],!torch.qint32> -> !torch.vtensor<[1,3,7,7],f32>
  return %16 : !torch.vtensor<[1,3,7,7],f32>
}
@ -0,0 +1,42 @@
// RUN: torch-mlir-opt --split-input-file --torch-match-quantized-custom-ops %s | FileCheck %s

// CHECK-LABEL: func.func @quantize_per_tensor
func.func @quantize_per_tensor(%arg0: !torch.vtensor<[1,3,8,8],f32>) -> !torch.vtensor<[1,3,8,8],si8> {
  %float = torch.constant.float 0.5
  %zp = torch.constant.int 17
  %min = torch.constant.int -128
  %max = torch.constant.int 127
  %dtype = torch.constant.int 1

  // CHECK-DAG: %[[SCALE:.+]] = torch.constant.float 5.000000e-01
  // CHECK-DAG: %[[ZP:.+]] = torch.constant.int 17
  // CHECK-DAG: %[[MIN:.+]] = torch.constant.int -128
  // CHECK-DAG: %[[MAX:.+]] = torch.constant.int 127
  // CHECK-DAG: %[[DTYPE:.+]] = torch.constant.int 1
  // CHECK-DAG: %[[QUANT:.+]] = torch.aten.quantize_per_tensor %arg0, %[[SCALE]], %[[ZP]], %[[DTYPE]] : !torch.vtensor<[1,3,8,8],f32>, !torch.float, !torch.int, !torch.int -> !torch.vtensor<[1,3,8,8],!torch.qint8>
  // CHECK-DAG: %[[REPR:.+]] = torch.aten.int_repr %[[QUANT]] : !torch.vtensor<[1,3,8,8],!torch.qint8> -> !torch.vtensor<[1,3,8,8],si8>
  // CHECK: torch.aten.clamp %[[REPR]], %[[MIN]], %[[MAX]]
  %0 = torch.operator "torch.quantized_decomposed.quantize_per_tensor"(%arg0, %float, %zp, %min, %max, %dtype) : (!torch.vtensor<[1,3,8,8],f32>, !torch.float, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.vtensor<[1,3,8,8],si8>
  return %0 : !torch.vtensor<[1,3,8,8],si8>
}

// -----

// CHECK-LABEL: func.func @dequantize_per_tensor
func.func @dequantize_per_tensor(%arg0: !torch.vtensor<[1,3,8,8],si8>) -> !torch.vtensor<[1,3,8,8],f32> {
  %float = torch.constant.float 0.5
  %zp = torch.constant.int 17
  %min = torch.constant.int -128
  %max = torch.constant.int 127
  %dtype = torch.constant.int 1

  // CHECK-DAG: %[[SCALE:.+]] = torch.constant.float 5.000000e-01
  // CHECK-DAG: %[[ZP:.+]] = torch.constant.int 17
  // CHECK-DAG: %[[MIN:.+]] = torch.constant.int -128
  // CHECK-DAG: %[[MAX:.+]] = torch.constant.int 127
  // CHECK-DAG: %[[CLAMP:.+]] = torch.aten.clamp %arg0, %[[MIN]], %[[MAX]] : !torch.vtensor<[1,3,8,8],si8>, !torch.int, !torch.int -> !torch.vtensor<[1,3,8,8],si8>
  // CHECK-DAG: %[[QINT:.+]] = torch.aten._make_per_tensor_quantized_tensor %[[CLAMP]], %[[SCALE]], %[[ZP]] : !torch.vtensor<[1,3,8,8],si8>, !torch.float, !torch.int -> !torch.vtensor<[1,3,8,8],!torch.qint8>
  // CHECK: %[[DEQUANT:.+]] = torch.aten.dequantize.tensor %[[QINT]] : !torch.vtensor<[1,3,8,8],!torch.qint8> -> !torch.vtensor<[1,3,8,8],f32>
  %13 = torch.operator "torch.quantized_decomposed.dequantize_per_tensor"(%arg0, %float, %zp, %min, %max, %dtype) : (!torch.vtensor<[1,3,8,8],si8>, !torch.float, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.vtensor<[1,3,8,8],f32>
  return %13 : !torch.vtensor<[1,3,8,8],f32>
}