[torch][quant] Quantized `torch.mm` for linalg with end-to-end test (#2750)

This includes custom op matching for decomposed operations and fusing
dequantization into dense operations. As a validation we compare
to the dequant+mm torch implementation.
pull/2803/head
Rob Suderman 2024-01-24 14:02:50 -08:00 committed by GitHub
parent 60bf6c25af
commit f6f890520b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
13 changed files with 577 additions and 8 deletions

@ -1 +1 @@
Subproject commit 0cb024b357aff294b1ba0f9d3de8f48ab684962b
Subproject commit eae82ac259ee5a58bc4070a414bc53239e18bad0

View File

@ -106,6 +106,10 @@ createDecomposeComplexOpsPass(ArrayRef<std::string> legalOps);
std::unique_ptr<OperationPass<func::FuncOp>> createRecomposeComplexOpsPass();
std::unique_ptr<OperationPass<func::FuncOp>> createFuseQuantizedOpsPass();
std::unique_ptr<OperationPass<func::FuncOp>>
createMatchQuantizedCustomOpsPass();
std::unique_ptr<OperationPass<ModuleOp>>
createReifyShapeCalculationsPass(StringRef extraLibrary);

View File

@ -258,6 +258,34 @@ def RecomposeComplexOps : Pass<"torch-recompose-complex-ops", "func::FuncOp"> {
}];
}
def FuseQuantizedOps : Pass<"torch-fuse-quantized-ops", "func::FuncOp"> {
let summary = "QDQ: Fuse recognized QDQ op sequences.";
let constructor = "mlir::torch::Torch::createFuseQuantizedOpsPass()";
let description = [{
Torch models often represents quantized operations as the sequence:
Dequantize
DenseOp
Quantize
This allows the existing dense operations to be used without specifically
representing quantized types. It is more computationally efficient to
perform the dense operation in the quantized domain, so we fuse the
quantization / dequantization behavior together and represent as purely
quantized operations.
}];
}
def MatchQuantizedCustomOps : Pass<"torch-match-quantized-custom-ops", "func::FuncOp"> {
let summary = "Match quantized operations that occur in different namespace.";
let constructor = "mlir::torch::Torch::createMatchQuantizedCustomOpsPass()";
let description = [{
Torch quantization utilities generated custom op versions of known aten
quantziation operations. We can match these specially named operations and
rewrite to the corresponding aten quantized operations.
We handle this post import to maintain a simplified import process.
}];
}
def ReifyShapeCalculations : Pass<"torch-reify-shape-calculations", "ModuleOp"> {
let summary = "Reify shape calculations.";
let constructor = [{

View File

@ -29,6 +29,13 @@ using namespace mlir::torch;
using namespace mlir::torch::Torch;
namespace {
static void getZeroPoint(Value value, Value &zeropoint) {
if (auto make = value.getDefiningOp<Aten_MakePerTensorQuantizedTensorOp>()) {
zeropoint = make.getZeroPoint();
}
}
class ConvertAtenMmOp : public OpConversionPattern<AtenMmOp> {
public:
using OpConversionPattern::OpConversionPattern;
@ -64,11 +71,27 @@ public:
op.getSelf().getType().cast<ValueTensorType>();
ValueTensorType rhsTorchType =
op.getMat2().getType().cast<ValueTensorType>();
Value lhsZeroPoint, rhsZeroPoint;
getZeroPoint(op.getSelf(), lhsZeroPoint);
getZeroPoint(op.getMat2(), rhsZeroPoint);
if (static_cast<bool>(lhsZeroPoint) != static_cast<bool>(lhsZeroPoint)) {
return rewriter.notifyMatchFailure(
op, "unsupported: aten.mm with mixed quantization");
}
if (lhsTorchType.getDtype() != rhsTorchType.getDtype()) {
return rewriter.notifyMatchFailure(
op, "unsupported: aten.mm with different input element types");
}
bool isUnsigned = torch_to_linalg::isUnsignedTorchType(lhsTorchType);
if (lhsZeroPoint && isUnsigned) {
return rewriter.notifyMatchFailure(
op, "unsupported: unsigned quantized matmul not supported");
}
Value lhsDim0 = rewriter.create<tensor::DimOp>(loc, lhs, 0);
Value rhsDim1 = rewriter.create<tensor::DimOp>(loc, rhs, 1);
@ -89,8 +112,26 @@ public:
rewriter, loc, ValueRange{lhsDim0, rhsDim1}, elementType);
Value matmul;
auto intType = dyn_cast<mlir::IntegerType>(lhsTorchType.getDtype());
if (intType && intType.isUnsigned()) {
if (lhsZeroPoint && !isUnsigned) {
lhsZeroPoint = typeConverter->materializeTargetConversion(
rewriter, loc,
getTypeConverter()->convertType(lhsZeroPoint.getType()),
lhsZeroPoint);
rhsZeroPoint = typeConverter->materializeTargetConversion(
rewriter, loc,
getTypeConverter()->convertType(rhsZeroPoint.getType()),
rhsZeroPoint);
lhsZeroPoint = rewriter.create<arith::TruncIOp>(
loc, rewriter.getI32Type(), lhsZeroPoint);
rhsZeroPoint = rewriter.create<arith::TruncIOp>(
loc, rewriter.getI32Type(), rhsZeroPoint);
matmul =
rewriter
.create<linalg::QuantizedMatmulOp>(
loc, zeroFill.getType(),
ValueRange{lhs, rhs, lhsZeroPoint, rhsZeroPoint}, zeroFill)
.getResult(0);
} else if (isUnsigned) {
matmul = rewriter
.create<linalg::MatmulUnsignedOp>(
loc, zeroFill.getType(), ValueRange{lhs, rhs}, zeroFill)

View File

@ -3,10 +3,12 @@ add_mlir_library(TorchMLIRTorchPasses
DecomposeComplexOps.cpp
DropAbstractInterpCalculations.cpp
EraseModuleInitializer.cpp
FuseQuantizedOps.cpp
Passes.cpp
GlobalizeObjectGraph.cpp
InlineGlobalSlots.cpp
LowerToBackendContract.cpp
MatchQuantizedOps.cpp
MaximizeValueSemantics.cpp
PrepareForGlobalizeObjectGraph.cpp
RecomposeComplexOps.cpp

View File

@ -0,0 +1,214 @@
//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//
#include "PassDetail.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
#include "torch-mlir/Dialect/Torch/Transforms/Passes.h"
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
using namespace mlir;
using namespace mlir::torch;
using namespace mlir::torch::Torch;
namespace {
template <typename SrcOp>
class QuantizeOperands : public OpRewritePattern<SrcOp> {
public:
using OpRewritePattern<SrcOp>::OpRewritePattern;
LogicalResult matchAndRewrite(SrcOp op,
PatternRewriter &rewriter) const override {
llvm::SmallVector<Value> operands(op->getOperands());
bool dequanted = false;
for (auto &operand : operands) {
if (auto dequant = operand.getDefiningOp<AtenDequantizeTensorOp>()) {
operand = dequant.getOperand();
dequanted = true;
}
if (auto dequant = operand.getDefiningOp<AtenDequantizeSelfOp>()) {
operand = dequant.getOperand();
dequanted = true;
}
}
if (!dequanted) {
return rewriter.notifyMatchFailure(op, "no dequantizations found");
}
rewriter.replaceOpWithNewOp<SrcOp>(op, op.getType(), operands);
return success();
}
};
template <typename SrcOp> class QuantizeBias : public OpRewritePattern<SrcOp> {
public:
using OpRewritePattern<SrcOp>::OpRewritePattern;
LogicalResult matchAndRewrite(SrcOp op,
PatternRewriter &rewriter) const override {
llvm::SmallVector<Value> operands(op->getOperands());
if (operands.size() < 3)
return failure();
Value bias = operands[2];
if (bias.getDefiningOp<AtenDequantizeTensorOp>())
return failure();
Value lhsScale;
if (auto qLhs =
operands[0].getDefiningOp<Aten_MakePerTensorQuantizedTensorOp>())
lhsScale = qLhs.getScale();
Value rhsScale;
if (auto qRhs =
operands[1].getDefiningOp<Aten_MakePerTensorQuantizedTensorOp>())
rhsScale = qRhs.getScale();
if (!rhsScale || !lhsScale)
return failure();
auto biasTy = bias.getType().cast<ValueTensorType>();
auto biasETy = biasTy.getOptionalDtype();
if (!biasETy || !isa<mlir::FloatType>(biasETy))
return failure();
Value biasScale = rewriter.create<AtenMulFloatOp>(
op.getLoc(), lhsScale.getType(), lhsScale, rhsScale);
Value zero = rewriter.create<Torch::ConstantIntOp>(
op.getLoc(), rewriter.getType<Torch::IntType>(),
rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0));
auto qi32Ty = rewriter.getType<QInt32Type>();
auto newBiasTy =
rewriter.getType<ValueTensorType>(biasTy.getOptionalSizes(), qi32Ty);
Value dtype = getDtypeIntValueForType(rewriter, op.getLoc(), qi32Ty);
bias = rewriter.create<AtenQuantizePerTensorOp>(
op.getLoc(), newBiasTy, bias, biasScale, zero, dtype);
operands[2] = bias;
rewriter.replaceOpWithNewOp<SrcOp>(op, op.getType(), operands);
return success();
}
};
template <typename SrcOp>
class QuantizeAccumulator : public OpRewritePattern<SrcOp> {
public:
using OpRewritePattern<SrcOp>::OpRewritePattern;
LogicalResult matchAndRewrite(SrcOp op,
PatternRewriter &rewriter) const override {
auto lhs = op.getOperand(0);
auto rhs = op.getOperand(1);
auto resultTy = dyn_cast_or_null<ValueTensorType>(op.getType());
if (!resultTy || !resultTy.hasDtype())
return failure();
Type resultETy = resultTy.getDtype();
if (!resultETy.isa<mlir::FloatType>())
return failure();
Value lhsScale;
if (auto defining =
lhs.template getDefiningOp<Aten_MakePerTensorQuantizedTensorOp>()) {
lhsScale = defining.getScale();
}
Value rhsScale;
if (auto defining =
rhs.template getDefiningOp<Aten_MakePerTensorQuantizedTensorOp>()) {
rhsScale = defining.getScale();
}
if (!lhsScale || !rhsScale)
return failure();
// Quantize the bias input to the expected result:
Value zero = rewriter.create<Torch::ConstantIntOp>(
op.getLoc(), rewriter.getType<Torch::IntType>(),
rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0));
auto qi32Ty = rewriter.getType<QInt32Type>();
Value biasScale = rewriter.create<AtenMulFloatOp>(
op.getLoc(), lhsScale.getType(), lhsScale, rhsScale);
// Update the quantied type:
llvm::SmallVector<Value> operands(op.getOperands());
auto newResultTy =
rewriter.getType<ValueTensorType>(resultTy.getOptionalSizes(), qi32Ty);
auto conv = rewriter.create<SrcOp>(op.getLoc(), newResultTy, operands);
// Attach the quantize information to the resulting quint32:
auto intReprTy = rewriter.getType<ValueTensorType>(
resultTy.getOptionalSizes(),
rewriter.getIntegerType(32, IntegerType::Signed));
auto intRepr = rewriter.create<AtenIntReprOp>(op.getLoc(), intReprTy, conv);
auto quantTy =
rewriter.getType<ValueTensorType>(resultTy.getOptionalSizes(), qi32Ty);
auto quant = rewriter.create<Aten_MakePerTensorQuantizedTensorOp>(
op.getLoc(), quantTy, intRepr, biasScale, zero);
auto dequant =
rewriter.create<AtenDequantizeTensorOp>(op.getLoc(), resultTy, quant);
rewriter.replaceOp(op, dequant);
return success();
}
};
template <typename SrcOp> class RemoveUnused : public OpRewritePattern<SrcOp> {
public:
using OpRewritePattern<SrcOp>::OpRewritePattern;
LogicalResult matchAndRewrite(SrcOp op,
PatternRewriter &rewriter) const override {
auto result = op.getResult();
if (result.use_empty()) {
op.erase();
return success();
}
return failure();
}
};
class FuseQuantizedOpsPass : public FuseQuantizedOpsBase<FuseQuantizedOpsPass> {
public:
void runOnOperation() override {
MLIRContext *context = &getContext();
RewritePatternSet patterns(context);
patterns
.insert<RemoveUnused<AtenDequantizeSelfOp>,
RemoveUnused<AtenDequantizeTensorOp>,
RemoveUnused<AtenQuantizePerTensorOp>,
QuantizeOperands<AtenConvolutionOp>, QuantizeOperands<AtenMmOp>,
QuantizeAccumulator<AtenConvolutionOp>,
QuantizeAccumulator<AtenMmOp>, QuantizeBias<AtenConvolutionOp>>(
context);
GreedyRewriteConfig config;
if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns),
config))) {
return signalPassFailure();
}
}
};
} // namespace
std::unique_ptr<OperationPass<func::FuncOp>>
mlir::torch::Torch::createFuseQuantizedOpsPass() {
return std::make_unique<FuseQuantizedOpsPass>();
}

View File

@ -0,0 +1,109 @@
//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//
#include "PassDetail.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
#include "torch-mlir/Dialect/Torch/Transforms/Passes.h"
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
using namespace mlir;
using namespace mlir::torch;
using namespace mlir::torch::Torch;
namespace {
Type getQuantizedType(MLIRContext *context, Type t) {
if (t.isSignlessInteger(8))
return Torch::QUInt8Type::get(context);
if (t.isInteger(8) || t.isSignedInteger(8))
return Torch::QInt8Type::get(context);
if (t.isInteger(32))
return Torch::QInt32Type::get(context);
return {};
}
class MatchQuantizeOperator : public OpRewritePattern<OperatorOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(OperatorOp op,
PatternRewriter &rewriter) const override {
if (op.getName() == "torch.quantized_decomposed.quantize_per_tensor") {
auto resultTy = cast<ValueTensorType>(op.getType(0));
auto qeTy = getQuantizedType(rewriter.getContext(), resultTy.getDtype());
if (!qeTy)
qeTy = resultTy.getDtype();
auto qTy =
rewriter.getType<ValueTensorType>(resultTy.getOptionalSizes(), qeTy);
Value quant = rewriter.create<AtenQuantizePerTensorOp>(
op.getLoc(), qTy,
/*self=*/op.getOperand(0), /*scale=*/op.getOperand(1),
/*zero_point=*/op.getOperand(2), /*dtype=*/op.getOperand(5));
if (qTy != resultTy) {
quant = rewriter.create<AtenIntReprOp>(op.getLoc(), resultTy, quant);
}
rewriter.replaceOpWithNewOp<AtenClampOp>(
op, resultTy, quant, op.getOperand(3), op.getOperand(4));
return success();
}
if (op.getName() == "torch.quantized_decomposed.dequantize_per_tensor") {
auto clamp = rewriter.create<AtenClampOp>(
op.getLoc(), op.getOperand(0).getType(), op.getOperand(0),
op.getOperand(3), op.getOperand(4));
auto clampTy = clamp.getType().cast<Torch::ValueTensorType>();
if (!clampTy.hasDtype())
return rewriter.notifyMatchFailure(op,
"dequantization has unknown dtype");
Type dtype = clampTy.getDtype();
Type qetype = getQuantizedType(op.getContext(), dtype);
if (!qetype)
return rewriter.notifyMatchFailure(op,
"dequantization has unknown qtype");
Type qTy = Torch::ValueTensorType::get(
op.getContext(), clampTy.getOptionalSizes(), qetype);
auto quant = rewriter.create<Aten_MakePerTensorQuantizedTensorOp>(
op.getLoc(), qTy, clamp, op.getOperand(1), op.getOperand(2));
rewriter.replaceOpWithNewOp<AtenDequantizeTensorOp>(
op, op.getResultTypes(), quant);
return success();
}
return failure();
}
};
class MatchQuantizedCustomOpsPass
: public MatchQuantizedCustomOpsBase<MatchQuantizedCustomOpsPass> {
public:
void runOnOperation() override {
MLIRContext *context = &getContext();
RewritePatternSet patterns(context);
patterns.insert<MatchQuantizeOperator>(context);
GreedyRewriteConfig config;
if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns),
config)))
return signalPassFailure();
}
};
} // namespace
std::unique_ptr<OperationPass<func::FuncOp>>
mlir::torch::Torch::createMatchQuantizedCustomOpsPass() {
return std::make_unique<MatchQuantizedCustomOpsPass>();
}

View File

@ -15,12 +15,13 @@
#include "mlir/Dialect/Tosa/Transforms/Passes.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/Passes.h"
#include "torch-mlir/Conversion/TorchConversionToMLProgram/TorchConversionToMLProgram.h"
#include "torch-mlir/Conversion/TorchToArith/TorchToArith.h"
#include "torch-mlir/Conversion/TorchToLinalg/TorchToLinalg.h"
#include "torch-mlir/Conversion/TorchToSCF/TorchToSCF.h"
#include "torch-mlir/Conversion/TorchToArith/TorchToArith.h"
#include "torch-mlir/Conversion/TorchToTMTensor/TorchToTMTensor.h"
#include "torch-mlir/Conversion/TorchToTosa/TorchToTosa.h"
#include "torch-mlir/Conversion/TorchConversionToMLProgram/TorchConversionToMLProgram.h"
#include "torch-mlir/Dialect/Torch/Transforms/Passes.h"
#ifdef TORCH_MLIR_ENABLE_STABLEHLO
#include "torch-mlir/Conversion/TorchToStablehlo/TorchToStablehlo.h"
#endif
@ -64,6 +65,9 @@ void mlir::torch::registerTorchConversionPasses() {
void TorchConversion::createTorchBackendToLinalgOnTensorsBackendPipeline(
OpPassManager &pm) {
// We want to fuse quantized operations together before lowering to linalg.
pm.addNestedPass<func::FuncOp>(Torch::createFuseQuantizedOpsPass());
// Lower to linalg + guards which is the input to codegen backends.
// We do this first as it tends to involve pattern-matching against constants,
// (e.g. dimensions which must be constant in a ranked programming model)

View File

@ -39,6 +39,38 @@ std::vector<torch::lazy::Shape> compute_shape_div(const at::Tensor& self,
return {Shape(self.scalar_type(), self.sizes().vec())};
}
std::vector<torch::lazy::Shape> compute_shape__make_per_tensor_quantized_tensor(
const at::Tensor &self, double scale, int64_t zero_point) {
if (self.scalar_type() == at::kChar)
return {Shape(at::kQInt8, self.sizes().vec())};
if (self.scalar_type() == at::kByte)
return {Shape(at::kQUInt8, self.sizes().vec())};
if (self.scalar_type() == at::kInt)
return {Shape(at::kQInt32, self.sizes().vec())};
assert(false);
}
std::vector<torch::lazy::Shape> compute_shape_int_repr(const at::Tensor &self) {
if (self.scalar_type() == at::kQInt8)
return {Shape(at::kChar, self.sizes().vec())};
if (self.scalar_type() == at::kQUInt8)
return {Shape(at::kByte, self.sizes().vec())};
if (self.scalar_type() == at::kQInt32)
return {Shape(at::kInt, self.sizes().vec())};
assert(false);
}
std::vector<torch::lazy::Shape>
compute_shape_dequantize(const at::Tensor &self) {
return {Shape(at::kFloat, self.sizes().vec())};
}
std::vector<torch::lazy::Shape>
compute_shape_quantize_per_tensor(const at::Tensor &self, double scale,
int64_t zero_point, at::ScalarType dtype) {
return {Shape(dtype, self.sizes().vec())};
}
std::vector<torch::lazy::Shape> compute_shape_isinf(const at::Tensor& self) {
return {Shape(at::kBool, self.sizes().vec())};
}
@ -102,6 +134,12 @@ std::vector<torch::lazy::Shape> compute_shape_var(
return {Shape(self.scalar_type(), {})};
}
std::vector<torch::lazy::Shape> compute_shape_nan_to_num(
const at::Tensor & self, c10::optional<double> nan,
c10::optional<double> posinf, c10::optional<double> neginf) {
return {Shape(self.scalar_type(), self.sizes().vec())};
}
std::vector<torch::lazy::Shape> compute_shape_hardtanh(
const at::Tensor& self, const at::Scalar& min_val,
const at::Scalar& max_val) {

View File

@ -315,6 +315,7 @@ TORCHDYNAMO_XFAIL_SET = {
# Dynamo does not support tracing quantized tensors
"ElementwiseDequantizePerTensorModule_basic",
"ElementwiseQuantizePerTensorModule_basic",
"AtenMmQuint8_basic",
# Dynamo not supporting conv_tbc
"ConvTbcModule_basic",
@ -1539,7 +1540,4 @@ LTC_XFAIL_SET = {
"ElementwiseBitwiseAndScalarInt64Module_basic",
"ElementwiseBitwiseAndScalarInt32Module_basic",
"ElementwiseBitwiseAndScalarInt8Module_basic",
"ElementwiseNanToNumModule_Basic",
"ElementwiseQuantizePerTensorModule_basic",
"ElementwiseDequantizePerTensorModule_basic"
}

View File

@ -262,3 +262,30 @@ class AtenMmIntTypes(torch.nn.Module):
@register_test_case(module_factory=lambda: AtenMmIntTypes())
def AtenMmIntTypes_basic(module, tu: TestUtils):
module.forward(tu.randint(16, 4, high=100), tu.randint(4, 16, high=100))
# ==============================================================================
class AtenMmQuint8(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([3, 4], torch.int8, True),
([4, 3], torch.int8, True),
])
def forward(self, x, y):
qx = torch._make_per_tensor_quantized_tensor(x, 0.1, 8)
qx = torch.dequantize(qx)
qy = torch._make_per_tensor_quantized_tensor(y, 0.1, 8)
qy = torch.dequantize(qy)
qz = torch.mm(qx, qy)
return qz
@register_test_case(module_factory=lambda: AtenMmQuint8())
def AtenMmQuint8_basic(module, tu: TestUtils):
module.forward(tu.randint(3, 4, low=-128, high=127).to(torch.int8),
tu.randint(4, 3, low=-128, high=127).to(torch.int8))

View File

@ -0,0 +1,62 @@
// RUN: torch-mlir-opt %s --split-input-file --torch-fuse-quantized-ops | FileCheck %s
// CHECK-LABEL: @mm
func.func @mm(%arg0: !torch.vtensor<[4, 4],si8>, %arg1: !torch.vtensor<[4, 4],si8>) -> !torch.vtensor<[4, 4],f32> {
%scale = torch.constant.float 0.5
%false = torch.constant.bool false
%zero = torch.constant.int 0
%one = torch.constant.int 1
%zp = torch.constant.int -128
%6 = torch.aten._make_per_tensor_quantized_tensor %arg0, %scale, %one : !torch.vtensor<[4, 4],si8>, !torch.float, !torch.int -> !torch.vtensor<[4, 4],!torch.qint8>
%7 = torch.aten.dequantize.tensor %6 : !torch.vtensor<[4, 4],!torch.qint8> -> !torch.vtensor<[4, 4],f32>
%12 = torch.aten._make_per_tensor_quantized_tensor %arg1, %scale, %zero : !torch.vtensor<[4, 4],si8>, !torch.float, !torch.int -> !torch.vtensor<[4, 4],!torch.qint8>
%13 = torch.aten.dequantize.tensor %12 : !torch.vtensor<[4, 4],!torch.qint8> -> !torch.vtensor<[4, 4],f32>
%16 = torch.aten.mm %7, %13 : !torch.vtensor<[4, 4],f32>, !torch.vtensor<[4, 4],f32> -> !torch.vtensor<[4, 4],f32>
// CHECK-DAG: %[[ZERO:.+]] = torch.constant.int 0
// CHECK-DAG: %[[QUARTER:.+]] = torch.constant.float 2.500000e-01
// CHECK-DAG: %[[HALF:.+]] = torch.constant.float 5.000000e-01
// CHECK-DAG: %[[ONE:.+]] = torch.constant.int 1
// CHECK-DAG: %[[QLHS:.+]] = torch.aten._make_per_tensor_quantized_tensor %arg0, %[[HALF:.+]], %[[ONE]] : !torch.vtensor<[4,4],si8>, !torch.float, !torch.int -> !torch.vtensor<[4,4],!torch.qint8>
// CHECK-DAG: %[[QRHS:.+]] = torch.aten._make_per_tensor_quantized_tensor %arg1, %[[HALF:.+]], %[[ZERO]] : !torch.vtensor<[4,4],si8>, !torch.float, !torch.int -> !torch.vtensor<[4,4],!torch.qint8>
// CHECK-DAG: %[[MM:.+]] = torch.aten.mm %[[QLHS]], %[[QRHS]] : !torch.vtensor<[4,4],!torch.qint8>, !torch.vtensor<[4,4],!torch.qint8> -> !torch.vtensor<[4,4],!torch.qint32>
// CHECK-DAG: %[[INT:.+]] = torch.aten.int_repr %[[MM]] : !torch.vtensor<[4,4],!torch.qint32> -> !torch.vtensor<[4,4],si32>
// CHECK-DAG: %[[QOUT:.+]] = torch.aten._make_per_tensor_quantized_tensor %[[INT]], %[[QUARTER]], %[[ZERO]] : !torch.vtensor<[4,4],si32>, !torch.float, !torch.int -> !torch.vtensor<[4,4],!torch.qint32>
// CHECK: %[[OUT:.+]] = torch.aten.dequantize.tensor %[[QOUT]] : !torch.vtensor<[4,4],!torch.qint32> -> !torch.vtensor<[4,4],f32>
return %16 : !torch.vtensor<[4, 4],f32>
}
// -----
// CHECK-LABEL: @convolution
func.func @convolution(%arg0: !torch.vtensor<[1,3,8,8],si8>, %arg1: !torch.vtensor<[3,3,2,2],si8>, %arg2 : !torch.vtensor<[3], f32>) -> !torch.vtensor<[1,3,7,7],f32> {
%scale = torch.constant.float 0.5
%false = torch.constant.bool false
%zero = torch.constant.int 0
%one = torch.constant.int 1
%zp = torch.constant.int -128
%6 = torch.aten._make_per_tensor_quantized_tensor %arg0, %scale, %one : !torch.vtensor<[1,3,8,8],si8>, !torch.float, !torch.int -> !torch.vtensor<[1,3,8,8],!torch.qint8>
%7 = torch.aten.dequantize.tensor %6 : !torch.vtensor<[1,3,8,8],!torch.qint8> -> !torch.vtensor<[1,3,8,8],f32>
%12 = torch.aten._make_per_tensor_quantized_tensor %arg1, %scale, %zero : !torch.vtensor<[3,3,2,2],si8>, !torch.float, !torch.int -> !torch.vtensor<[3,3,2,2],!torch.qint8>
%13 = torch.aten.dequantize.tensor %12 : !torch.vtensor<[3,3,2,2],!torch.qint8> -> !torch.vtensor<[3,3,2,2],f32>
%14 = torch.prim.ListConstruct %one, %one : (!torch.int, !torch.int) -> !torch.list<int>
%15 = torch.prim.ListConstruct %zero, %zero : (!torch.int, !torch.int) -> !torch.list<int>
%16 = torch.aten.convolution %7, %13, %arg2, %14, %15, %14, %false, %15, %one : !torch.vtensor<[1,3,8,8],f32>, !torch.vtensor<[3,3,2,2],f32>, !torch.vtensor<[3],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int -> !torch.vtensor<[1,3,7,7],f32>
// CHECK-DAG: %[[ZERO:.+]] = torch.constant.int 0
// CHECK-DAG: %[[SCALEO:.+]] = torch.constant.float 2.500000e-01
// CHECK-DAG: %[[DTYPE:.+]] = torch.constant.int 14
// CHECK-DAG: %[[HALF:.+]] = torch.constant.float 5.000000e-01
// CHECK-DAG: %[[FALSE:.+]] = torch.constant.bool false
// CHECK-DAG: %[[ONE:.+]] = torch.constant.int 1
// CHECK-DAG: %[[QLHS:.+]] = torch.aten._make_per_tensor_quantized_tensor %arg0, %[[HALF]], %[[ONE]] : !torch.vtensor<[1,3,8,8],si8>, !torch.float, !torch.int -> !torch.vtensor<[1,3,8,8],!torch.qint8>
// CHECK-DAG: %[[QRHS:.+]] = torch.aten._make_per_tensor_quantized_tensor %arg1, %[[HALF]], %[[ZERO]] : !torch.vtensor<[3,3,2,2],si8>, !torch.float, !torch.int -> !torch.vtensor<[3,3,2,2],!torch.qint8>
// CHECK-DAG: %[[ONES:.+]] = torch.prim.ListConstruct %[[ONE]], %[[ONE]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK-DAG: %[[ZEROS:.+]] = torch.prim.ListConstruct %[[ZERO]], %[[ZERO]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK-DAG: %[[QBIAS:.+]] = torch.aten.quantize_per_tensor %arg2, %[[SCALEO]], %[[ZERO]], %[[DTYPE]] : !torch.vtensor<[3],f32>, !torch.float, !torch.int, !torch.int -> !torch.vtensor<[3],!torch.qint32>
// CHECK-DAG: %[[CONV:.+]] = torch.aten.convolution %[[QLHS]], %[[QRHS]], %[[QBIAS]], %[[ONES]], %[[ZEROS]], %[[ONES]], %[[FALSE]], %[[ZEROS]], %[[ONE]] : !torch.vtensor<[1,3,8,8],!torch.qint8>, !torch.vtensor<[3,3,2,2],!torch.qint8>, !torch.vtensor<[3],!torch.qint32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int -> !torch.vtensor<[1,3,7,7],!torch.qint32>
// CHECK-DAG: %[[INT:.+]] = torch.aten.int_repr %[[CONV]] : !torch.vtensor<[1,3,7,7],!torch.qint32> -> !torch.vtensor<[1,3,7,7],si32>
// CHECK-DAG: %[[QOUT:.+]] = torch.aten._make_per_tensor_quantized_tensor %[[INT]], %[[SCALEO]], %[[ZERO]] : !torch.vtensor<[1,3,7,7],si32>, !torch.float, !torch.int -> !torch.vtensor<[1,3,7,7],!torch.qint32>
// CHECK: %[[FOUT:.+]] = torch.aten.dequantize.tensor %[[QOUT]] : !torch.vtensor<[1,3,7,7],!torch.qint32> -> !torch.vtensor<[1,3,7,7],f32>
return %16 : !torch.vtensor<[1,3,7,7],f32>
}

View File

@ -0,0 +1,42 @@
// RUN: torch-mlir-opt --split-input-file --torch-match-quantized-custom-ops %s | FileCheck %s
// CHECK-LABEL: func.func @quantize_per_tensor
func.func @quantize_per_tensor(%arg0: !torch.vtensor<[1,3,8,8],f32>) -> !torch.vtensor<[1,3,8,8],si8> {
%float = torch.constant.float 0.5
%zp = torch.constant.int 17
%min = torch.constant.int -128
%max = torch.constant.int 127
%dtype = torch.constant.int 1
// CHECK-DAG: %[[SCALE:.+]] = torch.constant.float 5.000000e-01
// CHECK-DAG: %[[ZP:.+]] = torch.constant.int 17
// CHECK-DAG: %[[MIN:.+]] = torch.constant.int -128
// CHECK-DAG: %[[MAX:.+]] = torch.constant.int 127
// CHECK-DAG: %[[DTYPE:.+]] = torch.constant.int 1
// CHECK-DAG: %[[QUANT:.+]] = torch.aten.quantize_per_tensor %arg0, %[[SCALE]], %[[ZP]], %[[DTYPE]] : !torch.vtensor<[1,3,8,8],f32>, !torch.float, !torch.int, !torch.int -> !torch.vtensor<[1,3,8,8],!torch.qint8>
// CHECK-DAG: %[[REPR:.+]] = torch.aten.int_repr %[[QUANT]] : !torch.vtensor<[1,3,8,8],!torch.qint8> -> !torch.vtensor<[1,3,8,8],si8>
// CHECK: torch.aten.clamp %[[REPR]], %[[MIN]], %[[MAX]]
%0 = torch.operator "torch.quantized_decomposed.quantize_per_tensor"(%arg0, %float, %zp, %min, %max, %dtype) : (!torch.vtensor<[1,3,8,8],f32>, !torch.float, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.vtensor<[1,3,8,8],si8>
return %0 : !torch.vtensor<[1,3,8,8],si8>
}
// -----
// CHECK-LABEL: func.func @dequantize_per_tensor
func.func @dequantize_per_tensor(%arg0: !torch.vtensor<[1,3,8,8],si8>) -> !torch.vtensor<[1,3,8,8],f32> {
%float = torch.constant.float 0.5
%zp = torch.constant.int 17
%min = torch.constant.int -128
%max = torch.constant.int 127
%dtype = torch.constant.int 1
// CHECK-DAG: %[[SCALE:.+]] = torch.constant.float 5.000000e-01
// CHECK-DAG: %[[ZP:.+]] = torch.constant.int 17
// CHECK-DAG: %[[MIN:.+]] = torch.constant.int -128
// CHECK-DAG: %[[MAX:.+]] = torch.constant.int 127
// CHECK-DAG: %[[CLAMP:.+]] = torch.aten.clamp %arg0, %[[MIN]], %[[MAX]] : !torch.vtensor<[1,3,8,8],si8>, !torch.int, !torch.int -> !torch.vtensor<[1,3,8,8],si8>
// CHECK-DAG: %[[QINT:.+]] = torch.aten._make_per_tensor_quantized_tensor %[[CLAMP]], %[[SCALE]], %[[ZP]] : !torch.vtensor<[1,3,8,8],si8>, !torch.float, !torch.int -> !torch.vtensor<[1,3,8,8],!torch.qint8>
// CHECK: %[[DEQUANT:.+]] = torch.aten.dequantize.tensor %[[QINT]] : !torch.vtensor<[1,3,8,8],!torch.qint8> -> !torch.vtensor<[1,3,8,8],f32>
%13 = torch.operator "torch.quantized_decomposed.dequantize_per_tensor"(%arg0, %float, %zp, %min, %max, %dtype) : (!torch.vtensor<[1,3,8,8],si8>, !torch.float, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.vtensor<[1,3,8,8],f32>
return %13 : !torch.vtensor<[1,3,8,8],f32>
}