mirror of https://github.com/llvm/torch-mlir
[torch][quant] Quantized `torch.mm` for linalg with end-to-end test (#2750)
This includes custom op matching for decomposed operations and fusing dequantization into dense operations. As a validation we compare to the dequant+mm torch implementation.pull/2803/head
parent
60bf6c25af
commit
f6f890520b
|
@ -1 +1 @@
|
||||||
Subproject commit 0cb024b357aff294b1ba0f9d3de8f48ab684962b
|
Subproject commit eae82ac259ee5a58bc4070a414bc53239e18bad0
|
|
@ -106,6 +106,10 @@ createDecomposeComplexOpsPass(ArrayRef<std::string> legalOps);
|
||||||
|
|
||||||
std::unique_ptr<OperationPass<func::FuncOp>> createRecomposeComplexOpsPass();
|
std::unique_ptr<OperationPass<func::FuncOp>> createRecomposeComplexOpsPass();
|
||||||
|
|
||||||
|
std::unique_ptr<OperationPass<func::FuncOp>> createFuseQuantizedOpsPass();
|
||||||
|
std::unique_ptr<OperationPass<func::FuncOp>>
|
||||||
|
createMatchQuantizedCustomOpsPass();
|
||||||
|
|
||||||
std::unique_ptr<OperationPass<ModuleOp>>
|
std::unique_ptr<OperationPass<ModuleOp>>
|
||||||
createReifyShapeCalculationsPass(StringRef extraLibrary);
|
createReifyShapeCalculationsPass(StringRef extraLibrary);
|
||||||
|
|
||||||
|
|
|
@ -258,6 +258,34 @@ def RecomposeComplexOps : Pass<"torch-recompose-complex-ops", "func::FuncOp"> {
|
||||||
}];
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def FuseQuantizedOps : Pass<"torch-fuse-quantized-ops", "func::FuncOp"> {
|
||||||
|
let summary = "QDQ: Fuse recognized QDQ op sequences.";
|
||||||
|
let constructor = "mlir::torch::Torch::createFuseQuantizedOpsPass()";
|
||||||
|
let description = [{
|
||||||
|
Torch models often represents quantized operations as the sequence:
|
||||||
|
Dequantize
|
||||||
|
DenseOp
|
||||||
|
Quantize
|
||||||
|
This allows the existing dense operations to be used without specifically
|
||||||
|
representing quantized types. It is more computationally efficient to
|
||||||
|
perform the dense operation in the quantized domain, so we fuse the
|
||||||
|
quantization / dequantization behavior together and represent as purely
|
||||||
|
quantized operations.
|
||||||
|
}];
|
||||||
|
}
|
||||||
|
|
||||||
|
def MatchQuantizedCustomOps : Pass<"torch-match-quantized-custom-ops", "func::FuncOp"> {
|
||||||
|
let summary = "Match quantized operations that occur in different namespace.";
|
||||||
|
let constructor = "mlir::torch::Torch::createMatchQuantizedCustomOpsPass()";
|
||||||
|
let description = [{
|
||||||
|
Torch quantization utilities generated custom op versions of known aten
|
||||||
|
quantziation operations. We can match these specially named operations and
|
||||||
|
rewrite to the corresponding aten quantized operations.
|
||||||
|
|
||||||
|
We handle this post import to maintain a simplified import process.
|
||||||
|
}];
|
||||||
|
}
|
||||||
|
|
||||||
def ReifyShapeCalculations : Pass<"torch-reify-shape-calculations", "ModuleOp"> {
|
def ReifyShapeCalculations : Pass<"torch-reify-shape-calculations", "ModuleOp"> {
|
||||||
let summary = "Reify shape calculations.";
|
let summary = "Reify shape calculations.";
|
||||||
let constructor = [{
|
let constructor = [{
|
||||||
|
|
|
@ -29,6 +29,13 @@ using namespace mlir::torch;
|
||||||
using namespace mlir::torch::Torch;
|
using namespace mlir::torch::Torch;
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
|
static void getZeroPoint(Value value, Value &zeropoint) {
|
||||||
|
if (auto make = value.getDefiningOp<Aten_MakePerTensorQuantizedTensorOp>()) {
|
||||||
|
zeropoint = make.getZeroPoint();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
class ConvertAtenMmOp : public OpConversionPattern<AtenMmOp> {
|
class ConvertAtenMmOp : public OpConversionPattern<AtenMmOp> {
|
||||||
public:
|
public:
|
||||||
using OpConversionPattern::OpConversionPattern;
|
using OpConversionPattern::OpConversionPattern;
|
||||||
|
@ -64,11 +71,27 @@ public:
|
||||||
op.getSelf().getType().cast<ValueTensorType>();
|
op.getSelf().getType().cast<ValueTensorType>();
|
||||||
ValueTensorType rhsTorchType =
|
ValueTensorType rhsTorchType =
|
||||||
op.getMat2().getType().cast<ValueTensorType>();
|
op.getMat2().getType().cast<ValueTensorType>();
|
||||||
|
|
||||||
|
Value lhsZeroPoint, rhsZeroPoint;
|
||||||
|
getZeroPoint(op.getSelf(), lhsZeroPoint);
|
||||||
|
getZeroPoint(op.getMat2(), rhsZeroPoint);
|
||||||
|
|
||||||
|
if (static_cast<bool>(lhsZeroPoint) != static_cast<bool>(lhsZeroPoint)) {
|
||||||
|
return rewriter.notifyMatchFailure(
|
||||||
|
op, "unsupported: aten.mm with mixed quantization");
|
||||||
|
}
|
||||||
|
|
||||||
if (lhsTorchType.getDtype() != rhsTorchType.getDtype()) {
|
if (lhsTorchType.getDtype() != rhsTorchType.getDtype()) {
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
op, "unsupported: aten.mm with different input element types");
|
op, "unsupported: aten.mm with different input element types");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool isUnsigned = torch_to_linalg::isUnsignedTorchType(lhsTorchType);
|
||||||
|
if (lhsZeroPoint && isUnsigned) {
|
||||||
|
return rewriter.notifyMatchFailure(
|
||||||
|
op, "unsupported: unsigned quantized matmul not supported");
|
||||||
|
}
|
||||||
|
|
||||||
Value lhsDim0 = rewriter.create<tensor::DimOp>(loc, lhs, 0);
|
Value lhsDim0 = rewriter.create<tensor::DimOp>(loc, lhs, 0);
|
||||||
Value rhsDim1 = rewriter.create<tensor::DimOp>(loc, rhs, 1);
|
Value rhsDim1 = rewriter.create<tensor::DimOp>(loc, rhs, 1);
|
||||||
|
|
||||||
|
@ -89,8 +112,26 @@ public:
|
||||||
rewriter, loc, ValueRange{lhsDim0, rhsDim1}, elementType);
|
rewriter, loc, ValueRange{lhsDim0, rhsDim1}, elementType);
|
||||||
|
|
||||||
Value matmul;
|
Value matmul;
|
||||||
auto intType = dyn_cast<mlir::IntegerType>(lhsTorchType.getDtype());
|
if (lhsZeroPoint && !isUnsigned) {
|
||||||
if (intType && intType.isUnsigned()) {
|
lhsZeroPoint = typeConverter->materializeTargetConversion(
|
||||||
|
rewriter, loc,
|
||||||
|
getTypeConverter()->convertType(lhsZeroPoint.getType()),
|
||||||
|
lhsZeroPoint);
|
||||||
|
rhsZeroPoint = typeConverter->materializeTargetConversion(
|
||||||
|
rewriter, loc,
|
||||||
|
getTypeConverter()->convertType(rhsZeroPoint.getType()),
|
||||||
|
rhsZeroPoint);
|
||||||
|
lhsZeroPoint = rewriter.create<arith::TruncIOp>(
|
||||||
|
loc, rewriter.getI32Type(), lhsZeroPoint);
|
||||||
|
rhsZeroPoint = rewriter.create<arith::TruncIOp>(
|
||||||
|
loc, rewriter.getI32Type(), rhsZeroPoint);
|
||||||
|
matmul =
|
||||||
|
rewriter
|
||||||
|
.create<linalg::QuantizedMatmulOp>(
|
||||||
|
loc, zeroFill.getType(),
|
||||||
|
ValueRange{lhs, rhs, lhsZeroPoint, rhsZeroPoint}, zeroFill)
|
||||||
|
.getResult(0);
|
||||||
|
} else if (isUnsigned) {
|
||||||
matmul = rewriter
|
matmul = rewriter
|
||||||
.create<linalg::MatmulUnsignedOp>(
|
.create<linalg::MatmulUnsignedOp>(
|
||||||
loc, zeroFill.getType(), ValueRange{lhs, rhs}, zeroFill)
|
loc, zeroFill.getType(), ValueRange{lhs, rhs}, zeroFill)
|
||||||
|
|
|
@ -3,10 +3,12 @@ add_mlir_library(TorchMLIRTorchPasses
|
||||||
DecomposeComplexOps.cpp
|
DecomposeComplexOps.cpp
|
||||||
DropAbstractInterpCalculations.cpp
|
DropAbstractInterpCalculations.cpp
|
||||||
EraseModuleInitializer.cpp
|
EraseModuleInitializer.cpp
|
||||||
|
FuseQuantizedOps.cpp
|
||||||
Passes.cpp
|
Passes.cpp
|
||||||
GlobalizeObjectGraph.cpp
|
GlobalizeObjectGraph.cpp
|
||||||
InlineGlobalSlots.cpp
|
InlineGlobalSlots.cpp
|
||||||
LowerToBackendContract.cpp
|
LowerToBackendContract.cpp
|
||||||
|
MatchQuantizedOps.cpp
|
||||||
MaximizeValueSemantics.cpp
|
MaximizeValueSemantics.cpp
|
||||||
PrepareForGlobalizeObjectGraph.cpp
|
PrepareForGlobalizeObjectGraph.cpp
|
||||||
RecomposeComplexOps.cpp
|
RecomposeComplexOps.cpp
|
||||||
|
|
|
@ -0,0 +1,214 @@
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
//
|
||||||
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
|
// See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
// Also available under a BSD-style license. See LICENSE.
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#include "PassDetail.h"
|
||||||
|
|
||||||
|
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
||||||
|
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
|
||||||
|
#include "torch-mlir/Dialect/Torch/Transforms/Passes.h"
|
||||||
|
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
|
||||||
|
|
||||||
|
using namespace mlir;
|
||||||
|
using namespace mlir::torch;
|
||||||
|
using namespace mlir::torch::Torch;
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
template <typename SrcOp>
|
||||||
|
class QuantizeOperands : public OpRewritePattern<SrcOp> {
|
||||||
|
public:
|
||||||
|
using OpRewritePattern<SrcOp>::OpRewritePattern;
|
||||||
|
|
||||||
|
LogicalResult matchAndRewrite(SrcOp op,
|
||||||
|
PatternRewriter &rewriter) const override {
|
||||||
|
llvm::SmallVector<Value> operands(op->getOperands());
|
||||||
|
|
||||||
|
bool dequanted = false;
|
||||||
|
for (auto &operand : operands) {
|
||||||
|
if (auto dequant = operand.getDefiningOp<AtenDequantizeTensorOp>()) {
|
||||||
|
operand = dequant.getOperand();
|
||||||
|
dequanted = true;
|
||||||
|
}
|
||||||
|
if (auto dequant = operand.getDefiningOp<AtenDequantizeSelfOp>()) {
|
||||||
|
operand = dequant.getOperand();
|
||||||
|
dequanted = true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!dequanted) {
|
||||||
|
return rewriter.notifyMatchFailure(op, "no dequantizations found");
|
||||||
|
}
|
||||||
|
|
||||||
|
rewriter.replaceOpWithNewOp<SrcOp>(op, op.getType(), operands);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename SrcOp> class QuantizeBias : public OpRewritePattern<SrcOp> {
|
||||||
|
public:
|
||||||
|
using OpRewritePattern<SrcOp>::OpRewritePattern;
|
||||||
|
|
||||||
|
LogicalResult matchAndRewrite(SrcOp op,
|
||||||
|
PatternRewriter &rewriter) const override {
|
||||||
|
llvm::SmallVector<Value> operands(op->getOperands());
|
||||||
|
if (operands.size() < 3)
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
Value bias = operands[2];
|
||||||
|
if (bias.getDefiningOp<AtenDequantizeTensorOp>())
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
Value lhsScale;
|
||||||
|
if (auto qLhs =
|
||||||
|
operands[0].getDefiningOp<Aten_MakePerTensorQuantizedTensorOp>())
|
||||||
|
lhsScale = qLhs.getScale();
|
||||||
|
|
||||||
|
Value rhsScale;
|
||||||
|
if (auto qRhs =
|
||||||
|
operands[1].getDefiningOp<Aten_MakePerTensorQuantizedTensorOp>())
|
||||||
|
rhsScale = qRhs.getScale();
|
||||||
|
|
||||||
|
if (!rhsScale || !lhsScale)
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
auto biasTy = bias.getType().cast<ValueTensorType>();
|
||||||
|
auto biasETy = biasTy.getOptionalDtype();
|
||||||
|
if (!biasETy || !isa<mlir::FloatType>(biasETy))
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
Value biasScale = rewriter.create<AtenMulFloatOp>(
|
||||||
|
op.getLoc(), lhsScale.getType(), lhsScale, rhsScale);
|
||||||
|
|
||||||
|
Value zero = rewriter.create<Torch::ConstantIntOp>(
|
||||||
|
op.getLoc(), rewriter.getType<Torch::IntType>(),
|
||||||
|
rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0));
|
||||||
|
|
||||||
|
auto qi32Ty = rewriter.getType<QInt32Type>();
|
||||||
|
auto newBiasTy =
|
||||||
|
rewriter.getType<ValueTensorType>(biasTy.getOptionalSizes(), qi32Ty);
|
||||||
|
Value dtype = getDtypeIntValueForType(rewriter, op.getLoc(), qi32Ty);
|
||||||
|
bias = rewriter.create<AtenQuantizePerTensorOp>(
|
||||||
|
op.getLoc(), newBiasTy, bias, biasScale, zero, dtype);
|
||||||
|
|
||||||
|
operands[2] = bias;
|
||||||
|
rewriter.replaceOpWithNewOp<SrcOp>(op, op.getType(), operands);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename SrcOp>
|
||||||
|
class QuantizeAccumulator : public OpRewritePattern<SrcOp> {
|
||||||
|
public:
|
||||||
|
using OpRewritePattern<SrcOp>::OpRewritePattern;
|
||||||
|
|
||||||
|
LogicalResult matchAndRewrite(SrcOp op,
|
||||||
|
PatternRewriter &rewriter) const override {
|
||||||
|
auto lhs = op.getOperand(0);
|
||||||
|
auto rhs = op.getOperand(1);
|
||||||
|
|
||||||
|
auto resultTy = dyn_cast_or_null<ValueTensorType>(op.getType());
|
||||||
|
if (!resultTy || !resultTy.hasDtype())
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
Type resultETy = resultTy.getDtype();
|
||||||
|
if (!resultETy.isa<mlir::FloatType>())
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
Value lhsScale;
|
||||||
|
if (auto defining =
|
||||||
|
lhs.template getDefiningOp<Aten_MakePerTensorQuantizedTensorOp>()) {
|
||||||
|
lhsScale = defining.getScale();
|
||||||
|
}
|
||||||
|
|
||||||
|
Value rhsScale;
|
||||||
|
if (auto defining =
|
||||||
|
rhs.template getDefiningOp<Aten_MakePerTensorQuantizedTensorOp>()) {
|
||||||
|
rhsScale = defining.getScale();
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!lhsScale || !rhsScale)
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
// Quantize the bias input to the expected result:
|
||||||
|
Value zero = rewriter.create<Torch::ConstantIntOp>(
|
||||||
|
op.getLoc(), rewriter.getType<Torch::IntType>(),
|
||||||
|
rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0));
|
||||||
|
|
||||||
|
auto qi32Ty = rewriter.getType<QInt32Type>();
|
||||||
|
Value biasScale = rewriter.create<AtenMulFloatOp>(
|
||||||
|
op.getLoc(), lhsScale.getType(), lhsScale, rhsScale);
|
||||||
|
|
||||||
|
// Update the quantied type:
|
||||||
|
llvm::SmallVector<Value> operands(op.getOperands());
|
||||||
|
|
||||||
|
auto newResultTy =
|
||||||
|
rewriter.getType<ValueTensorType>(resultTy.getOptionalSizes(), qi32Ty);
|
||||||
|
auto conv = rewriter.create<SrcOp>(op.getLoc(), newResultTy, operands);
|
||||||
|
|
||||||
|
// Attach the quantize information to the resulting quint32:
|
||||||
|
auto intReprTy = rewriter.getType<ValueTensorType>(
|
||||||
|
resultTy.getOptionalSizes(),
|
||||||
|
rewriter.getIntegerType(32, IntegerType::Signed));
|
||||||
|
auto intRepr = rewriter.create<AtenIntReprOp>(op.getLoc(), intReprTy, conv);
|
||||||
|
|
||||||
|
auto quantTy =
|
||||||
|
rewriter.getType<ValueTensorType>(resultTy.getOptionalSizes(), qi32Ty);
|
||||||
|
auto quant = rewriter.create<Aten_MakePerTensorQuantizedTensorOp>(
|
||||||
|
op.getLoc(), quantTy, intRepr, biasScale, zero);
|
||||||
|
auto dequant =
|
||||||
|
rewriter.create<AtenDequantizeTensorOp>(op.getLoc(), resultTy, quant);
|
||||||
|
rewriter.replaceOp(op, dequant);
|
||||||
|
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename SrcOp> class RemoveUnused : public OpRewritePattern<SrcOp> {
|
||||||
|
public:
|
||||||
|
using OpRewritePattern<SrcOp>::OpRewritePattern;
|
||||||
|
|
||||||
|
LogicalResult matchAndRewrite(SrcOp op,
|
||||||
|
PatternRewriter &rewriter) const override {
|
||||||
|
auto result = op.getResult();
|
||||||
|
if (result.use_empty()) {
|
||||||
|
op.erase();
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
class FuseQuantizedOpsPass : public FuseQuantizedOpsBase<FuseQuantizedOpsPass> {
|
||||||
|
public:
|
||||||
|
void runOnOperation() override {
|
||||||
|
MLIRContext *context = &getContext();
|
||||||
|
RewritePatternSet patterns(context);
|
||||||
|
patterns
|
||||||
|
.insert<RemoveUnused<AtenDequantizeSelfOp>,
|
||||||
|
RemoveUnused<AtenDequantizeTensorOp>,
|
||||||
|
RemoveUnused<AtenQuantizePerTensorOp>,
|
||||||
|
QuantizeOperands<AtenConvolutionOp>, QuantizeOperands<AtenMmOp>,
|
||||||
|
QuantizeAccumulator<AtenConvolutionOp>,
|
||||||
|
QuantizeAccumulator<AtenMmOp>, QuantizeBias<AtenConvolutionOp>>(
|
||||||
|
context);
|
||||||
|
|
||||||
|
GreedyRewriteConfig config;
|
||||||
|
if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns),
|
||||||
|
config))) {
|
||||||
|
return signalPassFailure();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
std::unique_ptr<OperationPass<func::FuncOp>>
|
||||||
|
mlir::torch::Torch::createFuseQuantizedOpsPass() {
|
||||||
|
return std::make_unique<FuseQuantizedOpsPass>();
|
||||||
|
}
|
|
@ -0,0 +1,109 @@
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
//
|
||||||
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
|
// See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
// Also available under a BSD-style license. See LICENSE.
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#include "PassDetail.h"
|
||||||
|
|
||||||
|
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
||||||
|
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
|
||||||
|
#include "torch-mlir/Dialect/Torch/Transforms/Passes.h"
|
||||||
|
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
|
||||||
|
|
||||||
|
using namespace mlir;
|
||||||
|
using namespace mlir::torch;
|
||||||
|
using namespace mlir::torch::Torch;
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
Type getQuantizedType(MLIRContext *context, Type t) {
|
||||||
|
if (t.isSignlessInteger(8))
|
||||||
|
return Torch::QUInt8Type::get(context);
|
||||||
|
if (t.isInteger(8) || t.isSignedInteger(8))
|
||||||
|
return Torch::QInt8Type::get(context);
|
||||||
|
if (t.isInteger(32))
|
||||||
|
return Torch::QInt32Type::get(context);
|
||||||
|
return {};
|
||||||
|
}
|
||||||
|
|
||||||
|
class MatchQuantizeOperator : public OpRewritePattern<OperatorOp> {
|
||||||
|
public:
|
||||||
|
using OpRewritePattern::OpRewritePattern;
|
||||||
|
LogicalResult matchAndRewrite(OperatorOp op,
|
||||||
|
PatternRewriter &rewriter) const override {
|
||||||
|
if (op.getName() == "torch.quantized_decomposed.quantize_per_tensor") {
|
||||||
|
auto resultTy = cast<ValueTensorType>(op.getType(0));
|
||||||
|
auto qeTy = getQuantizedType(rewriter.getContext(), resultTy.getDtype());
|
||||||
|
if (!qeTy)
|
||||||
|
qeTy = resultTy.getDtype();
|
||||||
|
|
||||||
|
auto qTy =
|
||||||
|
rewriter.getType<ValueTensorType>(resultTy.getOptionalSizes(), qeTy);
|
||||||
|
Value quant = rewriter.create<AtenQuantizePerTensorOp>(
|
||||||
|
op.getLoc(), qTy,
|
||||||
|
/*self=*/op.getOperand(0), /*scale=*/op.getOperand(1),
|
||||||
|
/*zero_point=*/op.getOperand(2), /*dtype=*/op.getOperand(5));
|
||||||
|
|
||||||
|
if (qTy != resultTy) {
|
||||||
|
quant = rewriter.create<AtenIntReprOp>(op.getLoc(), resultTy, quant);
|
||||||
|
}
|
||||||
|
|
||||||
|
rewriter.replaceOpWithNewOp<AtenClampOp>(
|
||||||
|
op, resultTy, quant, op.getOperand(3), op.getOperand(4));
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
if (op.getName() == "torch.quantized_decomposed.dequantize_per_tensor") {
|
||||||
|
auto clamp = rewriter.create<AtenClampOp>(
|
||||||
|
op.getLoc(), op.getOperand(0).getType(), op.getOperand(0),
|
||||||
|
op.getOperand(3), op.getOperand(4));
|
||||||
|
|
||||||
|
auto clampTy = clamp.getType().cast<Torch::ValueTensorType>();
|
||||||
|
if (!clampTy.hasDtype())
|
||||||
|
return rewriter.notifyMatchFailure(op,
|
||||||
|
"dequantization has unknown dtype");
|
||||||
|
|
||||||
|
Type dtype = clampTy.getDtype();
|
||||||
|
Type qetype = getQuantizedType(op.getContext(), dtype);
|
||||||
|
if (!qetype)
|
||||||
|
return rewriter.notifyMatchFailure(op,
|
||||||
|
"dequantization has unknown qtype");
|
||||||
|
|
||||||
|
Type qTy = Torch::ValueTensorType::get(
|
||||||
|
op.getContext(), clampTy.getOptionalSizes(), qetype);
|
||||||
|
auto quant = rewriter.create<Aten_MakePerTensorQuantizedTensorOp>(
|
||||||
|
op.getLoc(), qTy, clamp, op.getOperand(1), op.getOperand(2));
|
||||||
|
rewriter.replaceOpWithNewOp<AtenDequantizeTensorOp>(
|
||||||
|
op, op.getResultTypes(), quant);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
class MatchQuantizedCustomOpsPass
|
||||||
|
: public MatchQuantizedCustomOpsBase<MatchQuantizedCustomOpsPass> {
|
||||||
|
public:
|
||||||
|
void runOnOperation() override {
|
||||||
|
MLIRContext *context = &getContext();
|
||||||
|
RewritePatternSet patterns(context);
|
||||||
|
patterns.insert<MatchQuantizeOperator>(context);
|
||||||
|
|
||||||
|
GreedyRewriteConfig config;
|
||||||
|
if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns),
|
||||||
|
config)))
|
||||||
|
return signalPassFailure();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
std::unique_ptr<OperationPass<func::FuncOp>>
|
||||||
|
mlir::torch::Torch::createMatchQuantizedCustomOpsPass() {
|
||||||
|
return std::make_unique<MatchQuantizedCustomOpsPass>();
|
||||||
|
}
|
|
@ -15,12 +15,13 @@
|
||||||
#include "mlir/Dialect/Tosa/Transforms/Passes.h"
|
#include "mlir/Dialect/Tosa/Transforms/Passes.h"
|
||||||
#include "mlir/Pass/PassManager.h"
|
#include "mlir/Pass/PassManager.h"
|
||||||
#include "mlir/Transforms/Passes.h"
|
#include "mlir/Transforms/Passes.h"
|
||||||
|
#include "torch-mlir/Conversion/TorchConversionToMLProgram/TorchConversionToMLProgram.h"
|
||||||
|
#include "torch-mlir/Conversion/TorchToArith/TorchToArith.h"
|
||||||
#include "torch-mlir/Conversion/TorchToLinalg/TorchToLinalg.h"
|
#include "torch-mlir/Conversion/TorchToLinalg/TorchToLinalg.h"
|
||||||
#include "torch-mlir/Conversion/TorchToSCF/TorchToSCF.h"
|
#include "torch-mlir/Conversion/TorchToSCF/TorchToSCF.h"
|
||||||
#include "torch-mlir/Conversion/TorchToArith/TorchToArith.h"
|
|
||||||
#include "torch-mlir/Conversion/TorchToTMTensor/TorchToTMTensor.h"
|
#include "torch-mlir/Conversion/TorchToTMTensor/TorchToTMTensor.h"
|
||||||
#include "torch-mlir/Conversion/TorchToTosa/TorchToTosa.h"
|
#include "torch-mlir/Conversion/TorchToTosa/TorchToTosa.h"
|
||||||
#include "torch-mlir/Conversion/TorchConversionToMLProgram/TorchConversionToMLProgram.h"
|
#include "torch-mlir/Dialect/Torch/Transforms/Passes.h"
|
||||||
#ifdef TORCH_MLIR_ENABLE_STABLEHLO
|
#ifdef TORCH_MLIR_ENABLE_STABLEHLO
|
||||||
#include "torch-mlir/Conversion/TorchToStablehlo/TorchToStablehlo.h"
|
#include "torch-mlir/Conversion/TorchToStablehlo/TorchToStablehlo.h"
|
||||||
#endif
|
#endif
|
||||||
|
@ -64,6 +65,9 @@ void mlir::torch::registerTorchConversionPasses() {
|
||||||
|
|
||||||
void TorchConversion::createTorchBackendToLinalgOnTensorsBackendPipeline(
|
void TorchConversion::createTorchBackendToLinalgOnTensorsBackendPipeline(
|
||||||
OpPassManager &pm) {
|
OpPassManager &pm) {
|
||||||
|
// We want to fuse quantized operations together before lowering to linalg.
|
||||||
|
pm.addNestedPass<func::FuncOp>(Torch::createFuseQuantizedOpsPass());
|
||||||
|
|
||||||
// Lower to linalg + guards which is the input to codegen backends.
|
// Lower to linalg + guards which is the input to codegen backends.
|
||||||
// We do this first as it tends to involve pattern-matching against constants,
|
// We do this first as it tends to involve pattern-matching against constants,
|
||||||
// (e.g. dimensions which must be constant in a ranked programming model)
|
// (e.g. dimensions which must be constant in a ranked programming model)
|
||||||
|
|
|
@ -39,6 +39,38 @@ std::vector<torch::lazy::Shape> compute_shape_div(const at::Tensor& self,
|
||||||
return {Shape(self.scalar_type(), self.sizes().vec())};
|
return {Shape(self.scalar_type(), self.sizes().vec())};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
std::vector<torch::lazy::Shape> compute_shape__make_per_tensor_quantized_tensor(
|
||||||
|
const at::Tensor &self, double scale, int64_t zero_point) {
|
||||||
|
if (self.scalar_type() == at::kChar)
|
||||||
|
return {Shape(at::kQInt8, self.sizes().vec())};
|
||||||
|
if (self.scalar_type() == at::kByte)
|
||||||
|
return {Shape(at::kQUInt8, self.sizes().vec())};
|
||||||
|
if (self.scalar_type() == at::kInt)
|
||||||
|
return {Shape(at::kQInt32, self.sizes().vec())};
|
||||||
|
assert(false);
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<torch::lazy::Shape> compute_shape_int_repr(const at::Tensor &self) {
|
||||||
|
if (self.scalar_type() == at::kQInt8)
|
||||||
|
return {Shape(at::kChar, self.sizes().vec())};
|
||||||
|
if (self.scalar_type() == at::kQUInt8)
|
||||||
|
return {Shape(at::kByte, self.sizes().vec())};
|
||||||
|
if (self.scalar_type() == at::kQInt32)
|
||||||
|
return {Shape(at::kInt, self.sizes().vec())};
|
||||||
|
assert(false);
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<torch::lazy::Shape>
|
||||||
|
compute_shape_dequantize(const at::Tensor &self) {
|
||||||
|
return {Shape(at::kFloat, self.sizes().vec())};
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<torch::lazy::Shape>
|
||||||
|
compute_shape_quantize_per_tensor(const at::Tensor &self, double scale,
|
||||||
|
int64_t zero_point, at::ScalarType dtype) {
|
||||||
|
return {Shape(dtype, self.sizes().vec())};
|
||||||
|
}
|
||||||
|
|
||||||
std::vector<torch::lazy::Shape> compute_shape_isinf(const at::Tensor& self) {
|
std::vector<torch::lazy::Shape> compute_shape_isinf(const at::Tensor& self) {
|
||||||
return {Shape(at::kBool, self.sizes().vec())};
|
return {Shape(at::kBool, self.sizes().vec())};
|
||||||
}
|
}
|
||||||
|
@ -102,6 +134,12 @@ std::vector<torch::lazy::Shape> compute_shape_var(
|
||||||
return {Shape(self.scalar_type(), {})};
|
return {Shape(self.scalar_type(), {})};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
std::vector<torch::lazy::Shape> compute_shape_nan_to_num(
|
||||||
|
const at::Tensor & self, c10::optional<double> nan,
|
||||||
|
c10::optional<double> posinf, c10::optional<double> neginf) {
|
||||||
|
return {Shape(self.scalar_type(), self.sizes().vec())};
|
||||||
|
}
|
||||||
|
|
||||||
std::vector<torch::lazy::Shape> compute_shape_hardtanh(
|
std::vector<torch::lazy::Shape> compute_shape_hardtanh(
|
||||||
const at::Tensor& self, const at::Scalar& min_val,
|
const at::Tensor& self, const at::Scalar& min_val,
|
||||||
const at::Scalar& max_val) {
|
const at::Scalar& max_val) {
|
||||||
|
|
|
@ -315,6 +315,7 @@ TORCHDYNAMO_XFAIL_SET = {
|
||||||
# Dynamo does not support tracing quantized tensors
|
# Dynamo does not support tracing quantized tensors
|
||||||
"ElementwiseDequantizePerTensorModule_basic",
|
"ElementwiseDequantizePerTensorModule_basic",
|
||||||
"ElementwiseQuantizePerTensorModule_basic",
|
"ElementwiseQuantizePerTensorModule_basic",
|
||||||
|
"AtenMmQuint8_basic",
|
||||||
|
|
||||||
# Dynamo not supporting conv_tbc
|
# Dynamo not supporting conv_tbc
|
||||||
"ConvTbcModule_basic",
|
"ConvTbcModule_basic",
|
||||||
|
@ -1539,7 +1540,4 @@ LTC_XFAIL_SET = {
|
||||||
"ElementwiseBitwiseAndScalarInt64Module_basic",
|
"ElementwiseBitwiseAndScalarInt64Module_basic",
|
||||||
"ElementwiseBitwiseAndScalarInt32Module_basic",
|
"ElementwiseBitwiseAndScalarInt32Module_basic",
|
||||||
"ElementwiseBitwiseAndScalarInt8Module_basic",
|
"ElementwiseBitwiseAndScalarInt8Module_basic",
|
||||||
"ElementwiseNanToNumModule_Basic",
|
|
||||||
"ElementwiseQuantizePerTensorModule_basic",
|
|
||||||
"ElementwiseDequantizePerTensorModule_basic"
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -262,3 +262,30 @@ class AtenMmIntTypes(torch.nn.Module):
|
||||||
@register_test_case(module_factory=lambda: AtenMmIntTypes())
|
@register_test_case(module_factory=lambda: AtenMmIntTypes())
|
||||||
def AtenMmIntTypes_basic(module, tu: TestUtils):
|
def AtenMmIntTypes_basic(module, tu: TestUtils):
|
||||||
module.forward(tu.randint(16, 4, high=100), tu.randint(4, 16, high=100))
|
module.forward(tu.randint(16, 4, high=100), tu.randint(4, 16, high=100))
|
||||||
|
|
||||||
|
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
class AtenMmQuint8(torch.nn.Module):
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args([
|
||||||
|
None,
|
||||||
|
([3, 4], torch.int8, True),
|
||||||
|
([4, 3], torch.int8, True),
|
||||||
|
])
|
||||||
|
def forward(self, x, y):
|
||||||
|
qx = torch._make_per_tensor_quantized_tensor(x, 0.1, 8)
|
||||||
|
qx = torch.dequantize(qx)
|
||||||
|
qy = torch._make_per_tensor_quantized_tensor(y, 0.1, 8)
|
||||||
|
qy = torch.dequantize(qy)
|
||||||
|
qz = torch.mm(qx, qy)
|
||||||
|
return qz
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: AtenMmQuint8())
|
||||||
|
def AtenMmQuint8_basic(module, tu: TestUtils):
|
||||||
|
module.forward(tu.randint(3, 4, low=-128, high=127).to(torch.int8),
|
||||||
|
tu.randint(4, 3, low=-128, high=127).to(torch.int8))
|
||||||
|
|
|
@ -0,0 +1,62 @@
|
||||||
|
// RUN: torch-mlir-opt %s --split-input-file --torch-fuse-quantized-ops | FileCheck %s
|
||||||
|
|
||||||
|
// CHECK-LABEL: @mm
|
||||||
|
func.func @mm(%arg0: !torch.vtensor<[4, 4],si8>, %arg1: !torch.vtensor<[4, 4],si8>) -> !torch.vtensor<[4, 4],f32> {
|
||||||
|
%scale = torch.constant.float 0.5
|
||||||
|
%false = torch.constant.bool false
|
||||||
|
%zero = torch.constant.int 0
|
||||||
|
%one = torch.constant.int 1
|
||||||
|
%zp = torch.constant.int -128
|
||||||
|
%6 = torch.aten._make_per_tensor_quantized_tensor %arg0, %scale, %one : !torch.vtensor<[4, 4],si8>, !torch.float, !torch.int -> !torch.vtensor<[4, 4],!torch.qint8>
|
||||||
|
%7 = torch.aten.dequantize.tensor %6 : !torch.vtensor<[4, 4],!torch.qint8> -> !torch.vtensor<[4, 4],f32>
|
||||||
|
%12 = torch.aten._make_per_tensor_quantized_tensor %arg1, %scale, %zero : !torch.vtensor<[4, 4],si8>, !torch.float, !torch.int -> !torch.vtensor<[4, 4],!torch.qint8>
|
||||||
|
%13 = torch.aten.dequantize.tensor %12 : !torch.vtensor<[4, 4],!torch.qint8> -> !torch.vtensor<[4, 4],f32>
|
||||||
|
%16 = torch.aten.mm %7, %13 : !torch.vtensor<[4, 4],f32>, !torch.vtensor<[4, 4],f32> -> !torch.vtensor<[4, 4],f32>
|
||||||
|
|
||||||
|
// CHECK-DAG: %[[ZERO:.+]] = torch.constant.int 0
|
||||||
|
// CHECK-DAG: %[[QUARTER:.+]] = torch.constant.float 2.500000e-01
|
||||||
|
// CHECK-DAG: %[[HALF:.+]] = torch.constant.float 5.000000e-01
|
||||||
|
// CHECK-DAG: %[[ONE:.+]] = torch.constant.int 1
|
||||||
|
// CHECK-DAG: %[[QLHS:.+]] = torch.aten._make_per_tensor_quantized_tensor %arg0, %[[HALF:.+]], %[[ONE]] : !torch.vtensor<[4,4],si8>, !torch.float, !torch.int -> !torch.vtensor<[4,4],!torch.qint8>
|
||||||
|
// CHECK-DAG: %[[QRHS:.+]] = torch.aten._make_per_tensor_quantized_tensor %arg1, %[[HALF:.+]], %[[ZERO]] : !torch.vtensor<[4,4],si8>, !torch.float, !torch.int -> !torch.vtensor<[4,4],!torch.qint8>
|
||||||
|
// CHECK-DAG: %[[MM:.+]] = torch.aten.mm %[[QLHS]], %[[QRHS]] : !torch.vtensor<[4,4],!torch.qint8>, !torch.vtensor<[4,4],!torch.qint8> -> !torch.vtensor<[4,4],!torch.qint32>
|
||||||
|
// CHECK-DAG: %[[INT:.+]] = torch.aten.int_repr %[[MM]] : !torch.vtensor<[4,4],!torch.qint32> -> !torch.vtensor<[4,4],si32>
|
||||||
|
// CHECK-DAG: %[[QOUT:.+]] = torch.aten._make_per_tensor_quantized_tensor %[[INT]], %[[QUARTER]], %[[ZERO]] : !torch.vtensor<[4,4],si32>, !torch.float, !torch.int -> !torch.vtensor<[4,4],!torch.qint32>
|
||||||
|
// CHECK: %[[OUT:.+]] = torch.aten.dequantize.tensor %[[QOUT]] : !torch.vtensor<[4,4],!torch.qint32> -> !torch.vtensor<[4,4],f32>
|
||||||
|
return %16 : !torch.vtensor<[4, 4],f32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: @convolution
|
||||||
|
func.func @convolution(%arg0: !torch.vtensor<[1,3,8,8],si8>, %arg1: !torch.vtensor<[3,3,2,2],si8>, %arg2 : !torch.vtensor<[3], f32>) -> !torch.vtensor<[1,3,7,7],f32> {
|
||||||
|
%scale = torch.constant.float 0.5
|
||||||
|
%false = torch.constant.bool false
|
||||||
|
%zero = torch.constant.int 0
|
||||||
|
%one = torch.constant.int 1
|
||||||
|
%zp = torch.constant.int -128
|
||||||
|
%6 = torch.aten._make_per_tensor_quantized_tensor %arg0, %scale, %one : !torch.vtensor<[1,3,8,8],si8>, !torch.float, !torch.int -> !torch.vtensor<[1,3,8,8],!torch.qint8>
|
||||||
|
%7 = torch.aten.dequantize.tensor %6 : !torch.vtensor<[1,3,8,8],!torch.qint8> -> !torch.vtensor<[1,3,8,8],f32>
|
||||||
|
%12 = torch.aten._make_per_tensor_quantized_tensor %arg1, %scale, %zero : !torch.vtensor<[3,3,2,2],si8>, !torch.float, !torch.int -> !torch.vtensor<[3,3,2,2],!torch.qint8>
|
||||||
|
%13 = torch.aten.dequantize.tensor %12 : !torch.vtensor<[3,3,2,2],!torch.qint8> -> !torch.vtensor<[3,3,2,2],f32>
|
||||||
|
%14 = torch.prim.ListConstruct %one, %one : (!torch.int, !torch.int) -> !torch.list<int>
|
||||||
|
%15 = torch.prim.ListConstruct %zero, %zero : (!torch.int, !torch.int) -> !torch.list<int>
|
||||||
|
%16 = torch.aten.convolution %7, %13, %arg2, %14, %15, %14, %false, %15, %one : !torch.vtensor<[1,3,8,8],f32>, !torch.vtensor<[3,3,2,2],f32>, !torch.vtensor<[3],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int -> !torch.vtensor<[1,3,7,7],f32>
|
||||||
|
|
||||||
|
// CHECK-DAG: %[[ZERO:.+]] = torch.constant.int 0
|
||||||
|
// CHECK-DAG: %[[SCALEO:.+]] = torch.constant.float 2.500000e-01
|
||||||
|
// CHECK-DAG: %[[DTYPE:.+]] = torch.constant.int 14
|
||||||
|
// CHECK-DAG: %[[HALF:.+]] = torch.constant.float 5.000000e-01
|
||||||
|
// CHECK-DAG: %[[FALSE:.+]] = torch.constant.bool false
|
||||||
|
// CHECK-DAG: %[[ONE:.+]] = torch.constant.int 1
|
||||||
|
// CHECK-DAG: %[[QLHS:.+]] = torch.aten._make_per_tensor_quantized_tensor %arg0, %[[HALF]], %[[ONE]] : !torch.vtensor<[1,3,8,8],si8>, !torch.float, !torch.int -> !torch.vtensor<[1,3,8,8],!torch.qint8>
|
||||||
|
// CHECK-DAG: %[[QRHS:.+]] = torch.aten._make_per_tensor_quantized_tensor %arg1, %[[HALF]], %[[ZERO]] : !torch.vtensor<[3,3,2,2],si8>, !torch.float, !torch.int -> !torch.vtensor<[3,3,2,2],!torch.qint8>
|
||||||
|
// CHECK-DAG: %[[ONES:.+]] = torch.prim.ListConstruct %[[ONE]], %[[ONE]] : (!torch.int, !torch.int) -> !torch.list<int>
|
||||||
|
// CHECK-DAG: %[[ZEROS:.+]] = torch.prim.ListConstruct %[[ZERO]], %[[ZERO]] : (!torch.int, !torch.int) -> !torch.list<int>
|
||||||
|
// CHECK-DAG: %[[QBIAS:.+]] = torch.aten.quantize_per_tensor %arg2, %[[SCALEO]], %[[ZERO]], %[[DTYPE]] : !torch.vtensor<[3],f32>, !torch.float, !torch.int, !torch.int -> !torch.vtensor<[3],!torch.qint32>
|
||||||
|
// CHECK-DAG: %[[CONV:.+]] = torch.aten.convolution %[[QLHS]], %[[QRHS]], %[[QBIAS]], %[[ONES]], %[[ZEROS]], %[[ONES]], %[[FALSE]], %[[ZEROS]], %[[ONE]] : !torch.vtensor<[1,3,8,8],!torch.qint8>, !torch.vtensor<[3,3,2,2],!torch.qint8>, !torch.vtensor<[3],!torch.qint32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int -> !torch.vtensor<[1,3,7,7],!torch.qint32>
|
||||||
|
// CHECK-DAG: %[[INT:.+]] = torch.aten.int_repr %[[CONV]] : !torch.vtensor<[1,3,7,7],!torch.qint32> -> !torch.vtensor<[1,3,7,7],si32>
|
||||||
|
// CHECK-DAG: %[[QOUT:.+]] = torch.aten._make_per_tensor_quantized_tensor %[[INT]], %[[SCALEO]], %[[ZERO]] : !torch.vtensor<[1,3,7,7],si32>, !torch.float, !torch.int -> !torch.vtensor<[1,3,7,7],!torch.qint32>
|
||||||
|
// CHECK: %[[FOUT:.+]] = torch.aten.dequantize.tensor %[[QOUT]] : !torch.vtensor<[1,3,7,7],!torch.qint32> -> !torch.vtensor<[1,3,7,7],f32>
|
||||||
|
return %16 : !torch.vtensor<[1,3,7,7],f32>
|
||||||
|
}
|
|
@ -0,0 +1,42 @@
|
||||||
|
// RUN: torch-mlir-opt --split-input-file --torch-match-quantized-custom-ops %s | FileCheck %s
|
||||||
|
|
||||||
|
// CHECK-LABEL: func.func @quantize_per_tensor
|
||||||
|
func.func @quantize_per_tensor(%arg0: !torch.vtensor<[1,3,8,8],f32>) -> !torch.vtensor<[1,3,8,8],si8> {
|
||||||
|
%float = torch.constant.float 0.5
|
||||||
|
%zp = torch.constant.int 17
|
||||||
|
%min = torch.constant.int -128
|
||||||
|
%max = torch.constant.int 127
|
||||||
|
%dtype = torch.constant.int 1
|
||||||
|
|
||||||
|
// CHECK-DAG: %[[SCALE:.+]] = torch.constant.float 5.000000e-01
|
||||||
|
// CHECK-DAG: %[[ZP:.+]] = torch.constant.int 17
|
||||||
|
// CHECK-DAG: %[[MIN:.+]] = torch.constant.int -128
|
||||||
|
// CHECK-DAG: %[[MAX:.+]] = torch.constant.int 127
|
||||||
|
// CHECK-DAG: %[[DTYPE:.+]] = torch.constant.int 1
|
||||||
|
// CHECK-DAG: %[[QUANT:.+]] = torch.aten.quantize_per_tensor %arg0, %[[SCALE]], %[[ZP]], %[[DTYPE]] : !torch.vtensor<[1,3,8,8],f32>, !torch.float, !torch.int, !torch.int -> !torch.vtensor<[1,3,8,8],!torch.qint8>
|
||||||
|
// CHECK-DAG: %[[REPR:.+]] = torch.aten.int_repr %[[QUANT]] : !torch.vtensor<[1,3,8,8],!torch.qint8> -> !torch.vtensor<[1,3,8,8],si8>
|
||||||
|
// CHECK: torch.aten.clamp %[[REPR]], %[[MIN]], %[[MAX]]
|
||||||
|
%0 = torch.operator "torch.quantized_decomposed.quantize_per_tensor"(%arg0, %float, %zp, %min, %max, %dtype) : (!torch.vtensor<[1,3,8,8],f32>, !torch.float, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.vtensor<[1,3,8,8],si8>
|
||||||
|
return %0 : !torch.vtensor<[1,3,8,8],si8>
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: func.func @dequantize_per_tensor
|
||||||
|
func.func @dequantize_per_tensor(%arg0: !torch.vtensor<[1,3,8,8],si8>) -> !torch.vtensor<[1,3,8,8],f32> {
|
||||||
|
%float = torch.constant.float 0.5
|
||||||
|
%zp = torch.constant.int 17
|
||||||
|
%min = torch.constant.int -128
|
||||||
|
%max = torch.constant.int 127
|
||||||
|
%dtype = torch.constant.int 1
|
||||||
|
|
||||||
|
// CHECK-DAG: %[[SCALE:.+]] = torch.constant.float 5.000000e-01
|
||||||
|
// CHECK-DAG: %[[ZP:.+]] = torch.constant.int 17
|
||||||
|
// CHECK-DAG: %[[MIN:.+]] = torch.constant.int -128
|
||||||
|
// CHECK-DAG: %[[MAX:.+]] = torch.constant.int 127
|
||||||
|
// CHECK-DAG: %[[CLAMP:.+]] = torch.aten.clamp %arg0, %[[MIN]], %[[MAX]] : !torch.vtensor<[1,3,8,8],si8>, !torch.int, !torch.int -> !torch.vtensor<[1,3,8,8],si8>
|
||||||
|
// CHECK-DAG: %[[QINT:.+]] = torch.aten._make_per_tensor_quantized_tensor %[[CLAMP]], %[[SCALE]], %[[ZP]] : !torch.vtensor<[1,3,8,8],si8>, !torch.float, !torch.int -> !torch.vtensor<[1,3,8,8],!torch.qint8>
|
||||||
|
// CHECK: %[[DEQUANT:.+]] = torch.aten.dequantize.tensor %[[QINT]] : !torch.vtensor<[1,3,8,8],!torch.qint8> -> !torch.vtensor<[1,3,8,8],f32>
|
||||||
|
%13 = torch.operator "torch.quantized_decomposed.dequantize_per_tensor"(%arg0, %float, %zp, %min, %max, %dtype) : (!torch.vtensor<[1,3,8,8],si8>, !torch.float, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.vtensor<[1,3,8,8],f32>
|
||||||
|
return %13 : !torch.vtensor<[1,3,8,8],f32>
|
||||||
|
}
|
Loading…
Reference in New Issue