Generalize Operand Quantization in FuseQuantizeOps (#3327)

This change enables more customization with operand quantization, and
generalizes the patterns QuantizeOperands and QuantizeTransposeOperands
to QuantizeOperandsPastCommutingOps.

This allows for passing quantization through operations which are
functionally unaffected by quantization, such as view-like ops. The
purpose of this change is to address a myriad of quantization issues
seen in quantized onnx models that have some reshape-like operations
sandwiched in between a dequant and something like a matmul (whose other
operand is immediately quantizable).
pull/3331/head
zjgarvey 2024-05-12 22:49:59 -05:00 committed by GitHub
parent 0b7cbf5e60
commit 75d1d72059
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 169 additions and 98 deletions

View File

@ -13,6 +13,7 @@
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h" #include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
#include "torch-mlir/Dialect/Torch/Transforms/Passes.h" #include "torch-mlir/Dialect/Torch/Transforms/Passes.h"
#include "torch-mlir/Dialect/Torch/Utils/Utils.h" #include "torch-mlir/Dialect/Torch/Utils/Utils.h"
#include <stack>
using namespace mlir; using namespace mlir;
using namespace mlir::torch; using namespace mlir::torch;
@ -27,98 +28,112 @@ template <typename SrcOp> struct QuantInfo {
template <> struct QuantInfo<AtenReluOp> { template <> struct QuantInfo<AtenReluOp> {
static constexpr unsigned operandsToQuantize[1] = {0}; static constexpr unsigned operandsToQuantize[1] = {0};
}; };
template <typename SrcOp>
class QuantizeOperands : public OpRewritePattern<SrcOp> { // A QCommutingOp is an Op satisfying:
// 1. Has at most one tensor operand at index 0
// 2. Has a single output, which is a tensor
// 3. Satisfies the commutation relation:
// [MPTQT -> Dequant -> Op(float)] = [Op(int) -> MPTQT -> Dequant]
// where MPTQT = "Aten_MakePerTensorQuantizedTensorOp"
// and Dequant = "AtenDequantizeSelfOp" or "AtenDequantizeTensorOp"
bool isQCommutingOp(mlir::Operation *op) {
// if adding a new commuting op here, be sure to add a
// RemoveUnused pattern for that op to clean up afterwards
return llvm::isa<AtenTransposeIntOp, AtenReshapeOp, AtenSliceTensorOp>(op);
}
// The following conversion takes patterns of the form [op0 -> MPTQT -> dequant
// -> Op1 -> Op2 -> ... Opk -> SrcOp] to [op0 -> Int(Op1) -> Int(Op2) -> ... ->
// Int(Opk) -> MPTQT -> SrcOp] for any sequence of q commuting ops
// {Op1,Op2,...,Opk} with k <= depth.
// With depth = 0, this conversion will simply fuse any immediately quantizable
// operands: [MPTQT -> Dequant -> SrcOp (float operands)] to [MPTQT -> SrcOp(int
// operands)]
template <typename SrcOp, unsigned depth>
class QuantizeOperandsPastCommutingOps : public OpRewritePattern<SrcOp> {
public: public:
using OpRewritePattern<SrcOp>::OpRewritePattern; using OpRewritePattern<SrcOp>::OpRewritePattern;
LogicalResult matchAndRewrite(SrcOp op, LogicalResult matchAndRewrite(SrcOp op,
PatternRewriter &rewriter) const override { PatternRewriter &rewriter) const override {
llvm::SmallVector<Value> operands(op->getOperands());
mlir::Location loc = op.getLoc();
llvm::SmallVector<Value> operands(op->getOperands());
bool dequanted = false; bool dequanted = false;
auto f = [&dequanted](Value operand) {
if (auto dequant = operand.getDefiningOp<AtenDequantizeTensorOp>()) {
operand = dequant.getOperand();
dequanted = true;
}
if (auto dequant = operand.getDefiningOp<AtenDequantizeSelfOp>()) {
operand = dequant.getOperand();
dequanted = true;
}
return operand;
};
for (unsigned i : QuantInfo<SrcOp>::operandsToQuantize) { for (unsigned i : QuantInfo<SrcOp>::operandsToQuantize) {
operands[i] = f(operands[i]); Value operand = operands[i];
std::stack<mlir::Operation *> commutingOpStack;
Value dequantOpd, MPTQTOpd;
for (unsigned k = 0; k < depth + 1; k++) {
auto currOp = operand.getDefiningOp();
// Case 0 : currOp is a nullptr (e.g., operand is a block argument)
if (!currOp)
break;
// Case 1 : currOp is a q commuting op (continue loop)
if (isQCommutingOp(currOp)) {
commutingOpStack.push(currOp);
// set operand to currOp for next k-iteration
operand = currOp->getOperand(0);
continue;
}
// Case 2 : currOp is a dequant op (end loop)
if (llvm::isa<AtenDequantizeSelfOp, AtenDequantizeTensorOp>(currOp)) {
dequantOpd = currOp->getOperand(0);
auto MPTQTOp =
dequantOpd.getDefiningOp<Aten_MakePerTensorQuantizedTensorOp>();
MPTQTOpd = MPTQTOp.getOperand(0);
}
// either a dequant was found or chain broken, so break loop
break;
} }
if (!dequanted) { // move to next operand if this trace was unsuccessful
return rewriter.notifyMatchFailure(op, "no dequantizations found"); if (!MPTQTOpd)
} continue;
rewriter.replaceOpWithNewOp<SrcOp>(op, op.getType(), operands); // a successful trace occured, so set dequant to true
return success();
}
};
template <typename SrcOp>
class QuantizeTransposedOperands : public OpRewritePattern<SrcOp> {
public:
using OpRewritePattern<SrcOp>::OpRewritePattern;
LogicalResult matchAndRewrite(SrcOp op,
PatternRewriter &rewriter) const override {
llvm::SmallVector<Value> operands(op->getOperands());
unsigned numOperands = operands.size();
bool dequanted = false;
for (unsigned i = 0; i < numOperands; i++) {
if (auto trans = operands[i].getDefiningOp<AtenTransposeIntOp>()) {
auto transOperands = trans.getOperands();
Value dequantOperand;
if (auto dequant =
transOperands[0].getDefiningOp<AtenDequantizeSelfOp>()) {
dequantOperand = dequant.getOperand();
if (auto quant =
dequantOperand
.getDefiningOp<Aten_MakePerTensorQuantizedTensorOp>()) {
auto quantOperands = quant.getOperands();
auto qType = quantOperands[0]
.getType()
.cast<ValueTensorType>()
.getOptionalDtype();
auto torchQType =
cast<ValueTensorType>(quant.getType()).getOptionalDtype();
auto transQTy =
rewriter.getType<ValueTensorType>(trans.getResult()
.getType()
.cast<ValueTensorType>()
.getOptionalSizes(),
qType);
auto newQuantTy =
rewriter.getType<ValueTensorType>(trans.getResult()
.getType()
.cast<ValueTensorType>()
.getOptionalSizes(),
torchQType);
Value newTrans = rewriter.create<AtenTransposeIntOp>(
op.getLoc(), transQTy, quantOperands[0], transOperands[1],
transOperands[2]);
Value newQuant =
rewriter.create<Aten_MakePerTensorQuantizedTensorOp>(
op.getLoc(), newQuantTy, newTrans, quantOperands[1],
quantOperands[2]);
operands[i] = newQuant;
dequanted = true; dequanted = true;
// rewrite stack
Value oldOpd = MPTQTOpd;
Type intDType =
cast<ValueTensorType>(MPTQTOpd.getType()).getOptionalDtype();
while (!commutingOpStack.empty()) {
// get front of the commuting op stack and replace its first operand
// with oldOpd
auto currOp = commutingOpStack.top();
commutingOpStack.pop();
llvm::SmallVector<Value> currOperands(currOp->getOperands());
currOperands[0] = oldOpd;
// get new result type
auto oldType = cast<ValueTensorType>(currOp->getResultTypes()[0]);
auto intType =
rewriter.getType<ValueTensorType>(oldType.getSizes(), intDType);
// rewrite currOp to have new operands and result type
// store this as oldOpd for next loop
oldOpd = rewriter
.create(loc, (currOp->getName()).getIdentifier(),
currOperands, intType, currOp->getAttrs())
->getResult(0);
} }
// stack is empty, so oldOpd is now the corrected verion of the
// SrcOp's original operand
// convert operand -> SrcOp to oldOpd -> newMPTQTOp -> SrcOp
auto MPTQTOperands = dequantOpd.getDefiningOp()->getOperands();
auto qTorchType =
cast<ValueTensorType>(dequantOpd.getType()).getOptionalDtype();
auto newMPTQTType = rewriter.getType<ValueTensorType>(
cast<ValueTensorType>(operands[i].getType()).getSizes(), qTorchType);
operands[i] = rewriter.create<Aten_MakePerTensorQuantizedTensorOp>(
loc, newMPTQTType, oldOpd, MPTQTOperands[1], MPTQTOperands[2]);
} }
}
}
if (!dequanted) { if (!dequanted) {
return rewriter.notifyMatchFailure( return rewriter.notifyMatchFailure(op, "No dequantizations found.");
op, "no dequantized transpose inputs found.");
} }
rewriter.replaceOpWithNewOp<SrcOp>(op, op.getType(), operands); rewriter.replaceOpWithNewOp<SrcOp>(op, op.getType(), operands);
return success(); return success();
} }
@ -356,11 +371,13 @@ public:
RemoveUnused<AtenDequantizeTensorOp>, RemoveUnused<AtenDequantizeTensorOp>,
RemoveUnused<AtenQuantizePerTensorOp>, RemoveUnused<AtenQuantizePerTensorOp>,
RemoveUnused<Aten_MakePerTensorQuantizedTensorOp>, RemoveUnused<Aten_MakePerTensorQuantizedTensorOp>,
RemoveUnused<AtenTransposeIntOp>, QuantizeOperands<AtenConvolutionOp>, RemoveUnused<AtenTransposeIntOp>, RemoveUnused<AtenSliceTensorOp>,
QuantizeOperands<AtenMatmulOp>, QuantizeOperands<AtenReluOp>, RemoveUnused<AtenReshapeOp>,
QuantizeTransposedOperands<AtenMatmulOp>, QuantizeOperandsPastCommutingOps<AtenConvolutionOp, 0>,
QuantizeAccumulator<AtenMatmulOp>, QuantizeOperands<AtenMmOp>, QuantizeOperandsPastCommutingOps<AtenReluOp, 0>,
QuantizeTransposedOperands<AtenMmOp>, QuantizeAccumulator<AtenMmOp>, QuantizeOperandsPastCommutingOps<AtenMatmulOp, 2>,
QuantizeOperandsPastCommutingOps<AtenMmOp, 1>,
QuantizeAccumulator<AtenMmOp>, QuantizeAccumulator<AtenMatmulOp>,
QuantizeResultLikeOperand<AtenReluOp>, QuantizeBias<AtenConvolutionOp>>( QuantizeResultLikeOperand<AtenReluOp>, QuantizeBias<AtenConvolutionOp>>(
context); context);

View File

@ -28,6 +28,60 @@ func.func @mm(%arg0: !torch.vtensor<[4, 4],si8>, %arg1: !torch.vtensor<[4, 4],si
// ----- // -----
// CHECK-LABEL: @matmul_commuting
func.func @matmul_commuting(%arg0: !torch.vtensor<[2,128,32,32],si8>) -> !torch.vtensor<[1,1024,1024],f32> {
%float5.000000e-01 = torch.constant.float 5.000000e-01
%int0 = torch.constant.int 0
%int1 = torch.constant.int 1
%int-128 = torch.constant.int -128
%int2 = torch.constant.int 2
%int128 = torch.constant.int 128
%int1024 = torch.constant.int 1024
%int12 = torch.constant.int 12
%0 = torch.aten._make_per_tensor_quantized_tensor %arg0, %float5.000000e-01, %int-128 : !torch.vtensor<[2,128,32,32],si8>, !torch.float, !torch.int -> !torch.vtensor<[2,128,32,32],!torch.qint8>
%1 = torch.aten.dequantize.self %0 : !torch.vtensor<[2,128,32,32],!torch.qint8> -> !torch.vtensor<[2,128,32,32],f32>
%2 = torch.aten.slice.Tensor %1, %int0, %int0, %int1, %int1 : !torch.vtensor<[2,128,32,32],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,128,32,32],f32>
%3 = torch.aten.slice.Tensor %1, %int0, %int1, %int2, %int1 : !torch.vtensor<[2,128,32,32],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,128,32,32],f32>
%4 = torch.prim.ListConstruct %int1, %int128, %int1024 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
%5 = torch.aten.reshape %2, %4 : !torch.vtensor<[1,128,32,32],f32>, !torch.list<int> -> !torch.vtensor<[1,128,1024],f32>
%6 = torch.aten.reshape %3, %4 : !torch.vtensor<[1,128,32,32],f32>, !torch.list<int> -> !torch.vtensor<[1,128,1024],f32>
%7 = torch.aten.transpose.int %5, %int1, %int2 : !torch.vtensor<[1,128,1024],f32>, !torch.int, !torch.int -> !torch.vtensor<[1,1024,128],f32>
%8 = torch.aten.quantize_per_tensor %7, %float5.000000e-01, %int0, %int12 : !torch.vtensor<[1,1024,128],f32>, !torch.float, !torch.int, !torch.int -> !torch.vtensor<[1,1024,128],!torch.qint8>
%9 = torch.aten.int_repr %8 : !torch.vtensor<[1,1024,128],!torch.qint8> -> !torch.vtensor<[1,1024,128],si8>
%10 = torch.aten._make_per_tensor_quantized_tensor %9, %float5.000000e-01, %int0 : !torch.vtensor<[1,1024,128],si8>, !torch.float, !torch.int -> !torch.vtensor<[1,1024,128],!torch.qint8>
%11 = torch.aten.dequantize.self %10 : !torch.vtensor<[1,1024,128],!torch.qint8> -> !torch.vtensor<[1,1024,128],f32>
%12 = torch.aten.matmul %11, %6 : !torch.vtensor<[1,1024,128],f32>, !torch.vtensor<[1,128,1024],f32> -> !torch.vtensor<[1,1024,1024],f32>
// CHECK-DAG: %[[QUARTER:.+]] = torch.constant.float 2.500000e-01
// CHECK-DAG: %[[HALF:.+]] = torch.constant.float 5.000000e-01
// CHECK-DAG: %[[I0:.+]] = torch.constant.int 0
// CHECK-DAG: %[[I1:.+]] = torch.constant.int 1
// CHECK-DAG: %[[IN128:.+]] = torch.constant.int -128
// CHECK-DAG: %[[I2:.+]] = torch.constant.int 2
// CHECK-DAG: %[[I128:.+]] = torch.constant.int 128
// CHECK-DAG: %[[I1024:.+]] = torch.constant.int 1024
// CHECK-DAG: %[[I12:.+]] = torch.constant.int 12
// CHECK-DAG: %[[MPTQT0:.+]] = torch.aten._make_per_tensor_quantized_tensor %arg0, %[[HALF]], %[[IN128]] : !torch.vtensor<[2,128,32,32],si8>, !torch.float, !torch.int -> !torch.vtensor<[2,128,32,32],!torch.qint8>
// CHECK-DAG: %[[DQ0:.+]] = torch.aten.dequantize.self %[[MPTQT0]] : !torch.vtensor<[2,128,32,32],!torch.qint8> -> !torch.vtensor<[2,128,32,32],f32>
// CHECK-DAG: %[[SLICE0:.+]] = torch.aten.slice.Tensor %[[DQ0]], %[[I0]], %[[I0]], %[[I1]], %[[I1]] : !torch.vtensor<[2,128,32,32],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,128,32,32],f32>
// CHECK-DAG: %[[LIST:.+]] = torch.prim.ListConstruct %[[I1]], %[[I128]], %[[I1024]] : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
// CHECK-DAG: %[[RESHAPE0:.+]] = torch.aten.reshape %[[SLICE0]], %[[LIST]] : !torch.vtensor<[1,128,32,32],f32>, !torch.list<int> -> !torch.vtensor<[1,128,1024],f32>
// CHECK-DAG: %[[TR0:.+]] = torch.aten.transpose.int %[[RESHAPE0]], %[[I1]], %[[I2]] : !torch.vtensor<[1,128,1024],f32>, !torch.int, !torch.int -> !torch.vtensor<[1,1024,128],f32>
// CHECK-DAG: %[[Q0:.+]] = torch.aten.quantize_per_tensor %[[TR0]], %[[HALF]], %[[I0]], %[[I12]] : !torch.vtensor<[1,1024,128],f32>, !torch.float, !torch.int, !torch.int -> !torch.vtensor<[1,1024,128],!torch.qint8>
// CHECK-DAG: %[[IR0:.+]] = torch.aten.int_repr %[[Q0]] : !torch.vtensor<[1,1024,128],!torch.qint8> -> !torch.vtensor<[1,1024,128],si8>
// CHECK-DAG: %[[MPTQT1:.+]] = torch.aten._make_per_tensor_quantized_tensor %[[IR0]], %[[HALF]], %[[I0]] : !torch.vtensor<[1,1024,128],si8>, !torch.float, !torch.int -> !torch.vtensor<[1,1024,128],!torch.qint8>
// CHECK-DAG: %[[SLICE1:.+]] = torch.aten.slice.Tensor %arg0, %[[I0]], %[[I1]], %[[I2]], %[[I1]] : !torch.vtensor<[2,128,32,32],si8>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,128,32,32],si8>
// CHECK-DAG: %[[RESHAPE1:.+]] = torch.aten.reshape %[[SLICE1]], %[[LIST]] : !torch.vtensor<[1,128,32,32],si8>, !torch.list<int> -> !torch.vtensor<[1,128,1024],si8>
// CHECK-DAG: %[[MPTQT2:.+]] = torch.aten._make_per_tensor_quantized_tensor %[[RESHAPE1]], %[[HALF]], %[[IN128]] : !torch.vtensor<[1,128,1024],si8>, !torch.float, !torch.int -> !torch.vtensor<[1,128,1024],!torch.qint8>
// CHECK-DAG: %[[MATMUL:.+]] = torch.aten.matmul %[[MPTQT1]], %[[MPTQT2]] : !torch.vtensor<[1,1024,128],!torch.qint8>, !torch.vtensor<[1,128,1024],!torch.qint8> -> !torch.vtensor<[1,1024,1024],!torch.qint32>
// CHECK-DAG: %[[IR1:.+]] = torch.aten.int_repr %[[MATMUL]] : !torch.vtensor<[1,1024,1024],!torch.qint32> -> !torch.vtensor<[1,1024,1024],si32>
// CHECK-DAG: %[[MPTQT3:.+]] = torch.aten._make_per_tensor_quantized_tensor %[[IR1]], %[[QUARTER]], %[[I0]] : !torch.vtensor<[1,1024,1024],si32>, !torch.float, !torch.int -> !torch.vtensor<[1,1024,1024],!torch.qint32>
// CHECK-DAG: %[[DQ1:.+]] = torch.aten.dequantize.tensor %[[MPTQT3]] : !torch.vtensor<[1,1024,1024],!torch.qint32> -> !torch.vtensor<[1,1024,1024],f32>
return %12 : !torch.vtensor<[1,1024,1024],f32>
}
// -----
// CHECK-LABEL: @convolution_bias // CHECK-LABEL: @convolution_bias
func.func @convolution_bias(%arg0: !torch.vtensor<[1,3,8,8],si8>, %arg1: !torch.vtensor<[3,3,2,2],si8>, %arg2 : !torch.vtensor<[3], f32>) -> !torch.vtensor<[1,3,7,7],f32> { func.func @convolution_bias(%arg0: !torch.vtensor<[1,3,8,8],si8>, %arg1: !torch.vtensor<[3,3,2,2],si8>, %arg2 : !torch.vtensor<[3], f32>) -> !torch.vtensor<[1,3,7,7],f32> {
%scale = torch.constant.float 0.5 %scale = torch.constant.float 0.5
@ -43,21 +97,21 @@ func.func @convolution_bias(%arg0: !torch.vtensor<[1,3,8,8],si8>, %arg1: !torch.
%15 = torch.prim.ListConstruct %zero, %zero : (!torch.int, !torch.int) -> !torch.list<int> %15 = torch.prim.ListConstruct %zero, %zero : (!torch.int, !torch.int) -> !torch.list<int>
%16 = torch.aten.convolution %7, %13, %arg2, %14, %15, %14, %false, %15, %one : !torch.vtensor<[1,3,8,8],f32>, !torch.vtensor<[3,3,2,2],f32>, !torch.vtensor<[3],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int -> !torch.vtensor<[1,3,7,7],f32> %16 = torch.aten.convolution %7, %13, %arg2, %14, %15, %14, %false, %15, %one : !torch.vtensor<[1,3,8,8],f32>, !torch.vtensor<[3,3,2,2],f32>, !torch.vtensor<[3],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int -> !torch.vtensor<[1,3,7,7],f32>
// CHECK: %[[DTYPE:.+]] = torch.constant.int 14 // CHECK-DAG: %[[DTYPE:.+]] = torch.constant.int 14
// CHECK: %[[SCALEO:.+]] = torch.constant.float 2.500000e-01 // CHECK-DAG: %[[SCALEO:.+]] = torch.constant.float 2.500000e-01
// CHECK: %[[HALF:.+]] = torch.constant.float 5.000000e-01 // CHECK-DAG: %[[HALF:.+]] = torch.constant.float 5.000000e-01
// CHECK: %[[FALSE:.+]] = torch.constant.bool false // CHECK-DAG: %[[FALSE:.+]] = torch.constant.bool false
// CHECK: %[[ZERO:.+]] = torch.constant.int 0 // CHECK-DAG: %[[ZERO:.+]] = torch.constant.int 0
// CHECK: %[[ONE:.+]] = torch.constant.int 1 // CHECK-DAG: %[[ONE:.+]] = torch.constant.int 1
// CHECK: %[[QLHS:.+]] = torch.aten._make_per_tensor_quantized_tensor %arg0, %[[HALF]], %[[ONE]] : !torch.vtensor<[1,3,8,8],si8>, !torch.float, !torch.int -> !torch.vtensor<[1,3,8,8],!torch.qint8> // CHECK-DAG: %[[QLHS:.+]] = torch.aten._make_per_tensor_quantized_tensor %arg0, %[[HALF]], %[[ONE]] : !torch.vtensor<[1,3,8,8],si8>, !torch.float, !torch.int -> !torch.vtensor<[1,3,8,8],!torch.qint8>
// CHECK: %[[QRHS:.+]] = torch.aten._make_per_tensor_quantized_tensor %arg1, %[[HALF]], %[[ZERO]] : !torch.vtensor<[3,3,2,2],si8>, !torch.float, !torch.int -> !torch.vtensor<[3,3,2,2],!torch.qint8> // CHECK-DAG: %[[QRHS:.+]] = torch.aten._make_per_tensor_quantized_tensor %arg1, %[[HALF]], %[[ZERO]] : !torch.vtensor<[3,3,2,2],si8>, !torch.float, !torch.int -> !torch.vtensor<[3,3,2,2],!torch.qint8>
// CHECK: %[[ONES:.+]] = torch.prim.ListConstruct %[[ONE]], %[[ONE]] : (!torch.int, !torch.int) -> !torch.list<int> // CHECK-DAG: %[[ONES:.+]] = torch.prim.ListConstruct %[[ONE]], %[[ONE]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[ZEROS:.+]] = torch.prim.ListConstruct %[[ZERO]], %[[ZERO]] : (!torch.int, !torch.int) -> !torch.list<int> // CHECK-DAG: %[[ZEROS:.+]] = torch.prim.ListConstruct %[[ZERO]], %[[ZERO]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[QBIAS:.+]] = torch.aten.quantize_per_tensor %arg2, %[[SCALEO]], %[[ZERO]], %[[DTYPE]] : !torch.vtensor<[3],f32>, !torch.float, !torch.int, !torch.int -> !torch.vtensor<[3],!torch.qint32> // CHECK-DAG: %[[QBIAS:.+]] = torch.aten.quantize_per_tensor %arg2, %[[SCALEO]], %[[ZERO]], %[[DTYPE]] : !torch.vtensor<[3],f32>, !torch.float, !torch.int, !torch.int -> !torch.vtensor<[3],!torch.qint32>
// CHECK: %[[INT:.+]] = torch.aten.int_repr %[[QBIAS]] : !torch.vtensor<[3],!torch.qint32> -> !torch.vtensor<[3],si32> // CHECK-DAG: %[[INT:.+]] = torch.aten.int_repr %[[QBIAS]] : !torch.vtensor<[3],!torch.qint32> -> !torch.vtensor<[3],si32>
// CHECK: %[[CONV:.+]] = torch.aten.convolution %[[QLHS]], %[[QRHS]], %[[INT]], %[[ONES]], %[[ZEROS]], %[[ONES]], %[[FALSE]], %[[ZEROS]], %[[ONE]] : !torch.vtensor<[1,3,8,8],!torch.qint8>, !torch.vtensor<[3,3,2,2],!torch.qint8>, !torch.vtensor<[3],si32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int -> !torch.vtensor<[1,3,7,7],si32> // CHECK-DAG: %[[CONV:.+]] = torch.aten.convolution %[[QLHS]], %[[QRHS]], %[[INT]], %[[ONES]], %[[ZEROS]], %[[ONES]], %[[FALSE]], %[[ZEROS]], %[[ONE]] : !torch.vtensor<[1,3,8,8],!torch.qint8>, !torch.vtensor<[3,3,2,2],!torch.qint8>, !torch.vtensor<[3],si32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int -> !torch.vtensor<[1,3,7,7],si32>
// CHECK: %[[QOUT:.+]] = torch.aten._make_per_tensor_quantized_tensor %[[CONV]], %[[SCALEO]], %[[ZERO]] : !torch.vtensor<[1,3,7,7],si32>, !torch.float, !torch.int -> !torch.vtensor<[1,3,7,7],!torch.qint32> // CHECK-DAG: %[[QOUT:.+]] = torch.aten._make_per_tensor_quantized_tensor %[[CONV]], %[[SCALEO]], %[[ZERO]] : !torch.vtensor<[1,3,7,7],si32>, !torch.float, !torch.int -> !torch.vtensor<[1,3,7,7],!torch.qint32>
// CHECK: %[[FOUT:.+]] = torch.aten.dequantize.tensor %[[QOUT]] : !torch.vtensor<[1,3,7,7],!torch.qint32> -> !torch.vtensor<[1,3,7,7],f32> // CHECK-DAG: %[[FOUT:.+]] = torch.aten.dequantize.tensor %[[QOUT]] : !torch.vtensor<[1,3,7,7],!torch.qint32> -> !torch.vtensor<[1,3,7,7],f32>
return %16 : !torch.vtensor<[1,3,7,7],f32> return %16 : !torch.vtensor<[1,3,7,7],f32>
} }