mirror of https://github.com/llvm/torch-mlir
[onnx] Lower `onnx.QLinearMatMul` lowering to `torch` operators (#2776)
We can plumb the linear matmul into pytorch using its quantized types with side channel information. To handle the final int8 operation we dequantize and requantize.pull/2865/head
parent
894805dd5e
commit
60bf6c25af
|
@ -55,7 +55,9 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
|
|||
Value zeropoint = operands[2];
|
||||
|
||||
auto scaleTy = scale.getType().dyn_cast<Torch::ValueTensorType>();
|
||||
if (!scaleTy || !scaleTy.hasSizes()) return rewriter.notifyMatchFailure(binder.op, "requires known rank");
|
||||
if (!scaleTy || !scaleTy.hasSizes())
|
||||
return rewriter.notifyMatchFailure(binder.op,
|
||||
"requires known rank");
|
||||
if (!resultType.hasDtype())
|
||||
return rewriter.notifyMatchFailure(
|
||||
binder.op, "requires known result dtype");
|
||||
|
@ -89,9 +91,135 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
|
|||
}
|
||||
|
||||
return failure();
|
||||
});
|
||||
patterns.onOp(
|
||||
"QLinearMatMul", 1,
|
||||
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
|
||||
Torch::ValueTensorType resultType;
|
||||
llvm::SmallVector<Value> operands;
|
||||
if (binder.tensorOperands(operands, 8) ||
|
||||
binder.tensorResultType(resultType))
|
||||
return failure();
|
||||
Value a = operands[0];
|
||||
Value aScale = operands[1];
|
||||
Value aZp = operands[2];
|
||||
Value b = operands[3];
|
||||
Value bScale = operands[4];
|
||||
Value bZp = operands[5];
|
||||
Value cScale = operands[6];
|
||||
Value cZp = operands[7];
|
||||
|
||||
}
|
||||
);
|
||||
auto check = [](Value v) {
|
||||
auto vTy = v.getType().cast<Torch::ValueTensorType>();
|
||||
for (auto dim : vTy.getSizes())
|
||||
if (dim != 1)
|
||||
return false;
|
||||
return true;
|
||||
};
|
||||
if (!check(aScale) || !check(aZp) || !check(bScale) || !check(bZp) ||
|
||||
!check(cScale) || !check(cScale))
|
||||
return rewriter.notifyMatchFailure(
|
||||
binder.op, "not supported for non per-tensor quantization");
|
||||
|
||||
Value emptyList = rewriter.create<Torch::PrimListConstructOp>(
|
||||
binder.getLoc(),
|
||||
rewriter.getType<Torch::ListType>(
|
||||
rewriter.getType<Torch::IntType>()),
|
||||
ValueRange{});
|
||||
auto extract = [&rewriter, &binder, &emptyList](Value v) {
|
||||
auto vTy = v.getType().cast<Torch::ValueTensorType>();
|
||||
if (!vTy.getSizes().empty()) {
|
||||
vTy = rewriter.getType<Torch::ValueTensorType>(
|
||||
ArrayRef<int64_t>({}), vTy.getOptionalDtype());
|
||||
v = rewriter.create<Torch::AtenReshapeOp>(binder.getLoc(), vTy, v,
|
||||
emptyList);
|
||||
}
|
||||
|
||||
Type extractTy = rewriter.getType<Torch::FloatType>();
|
||||
if (isa<IntegerType>(vTy.getDtype()))
|
||||
extractTy = rewriter.getType<Torch::IntType>();
|
||||
|
||||
return rewriter.create<Torch::AtenItemOp>(binder.getLoc(), extractTy,
|
||||
v);
|
||||
};
|
||||
|
||||
aZp = extract(aZp);
|
||||
bZp = extract(bZp);
|
||||
cZp = extract(cZp);
|
||||
aScale = extract(aScale);
|
||||
bScale = extract(bScale);
|
||||
cScale = extract(cScale);
|
||||
|
||||
auto getQTy =
|
||||
[&rewriter](Torch::ValueTensorType ty) -> Torch::ValueTensorType {
|
||||
auto dt = ty.getDtype();
|
||||
Type newDt;
|
||||
if (dt.isUnsignedInteger(8)) {
|
||||
newDt = rewriter.getType<Torch::QUInt8Type>();
|
||||
} else if (dt.isSignedInteger(8)) {
|
||||
newDt = rewriter.getType<Torch::QInt8Type>();
|
||||
} else if (dt.isSignedInteger(32)) {
|
||||
newDt = rewriter.getType<Torch::QInt32Type>();
|
||||
} else {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
return rewriter.getType<Torch::ValueTensorType>(ty.getOptionalSizes(),
|
||||
newDt);
|
||||
};
|
||||
|
||||
auto make = [&rewriter, &binder, &getQTy](Value v, Value scale,
|
||||
Value zp) -> Value {
|
||||
auto ty = v.getType().cast<Torch::ValueTensorType>();
|
||||
auto newTy = getQTy(ty);
|
||||
return rewriter.create<Torch::Aten_MakePerTensorQuantizedTensorOp>(
|
||||
binder.getLoc(), newTy, v, scale, zp);
|
||||
};
|
||||
|
||||
a = make(a, aScale, aZp);
|
||||
b = make(b, bScale, bZp);
|
||||
|
||||
auto cTy = rewriter.getType<Torch::ValueTensorType>(
|
||||
resultType.getOptionalSizes(),
|
||||
rewriter.getIntegerType(32, /*issigned=*/true));
|
||||
|
||||
Value c;
|
||||
if (cTy.getSizes().size() == 2) {
|
||||
c = rewriter.create<Torch::AtenMmOp>(binder.getLoc(), cTy, a, b);
|
||||
} else {
|
||||
c = rewriter.create<Torch::AtenBmmOp>(binder.getLoc(), cTy, a, b);
|
||||
}
|
||||
|
||||
cTy = rewriter.getType<Torch::ValueTensorType>(
|
||||
resultType.getOptionalSizes(),
|
||||
rewriter.getType<Torch::QInt32Type>());
|
||||
|
||||
Value mmScale = rewriter.create<Torch::AtenMulFloatOp>(
|
||||
binder.getLoc(), rewriter.getType<Torch::FloatType>(), aScale,
|
||||
bScale);
|
||||
Value mmZp = rewriter.create<Torch::ConstantIntOp>(
|
||||
binder.getLoc(), rewriter.getType<Torch::IntType>(),
|
||||
rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0));
|
||||
c = rewriter.create<Torch::Aten_MakePerTensorQuantizedTensorOp>(
|
||||
binder.getLoc(), cTy, c, mmScale, mmZp);
|
||||
cTy = rewriter.getType<Torch::ValueTensorType>(
|
||||
resultType.getOptionalSizes(), rewriter.getF32Type());
|
||||
|
||||
c = rewriter.create<Torch::AtenDequantizeSelfOp>(binder.getLoc(), cTy,
|
||||
c);
|
||||
cTy = getQTy(resultType);
|
||||
Value dtyVal = rewriter.create<Torch::ConstantIntOp>(
|
||||
binder.getLoc(), rewriter.getType<Torch::IntType>(),
|
||||
rewriter.getIntegerAttr(
|
||||
rewriter.getIntegerType(64),
|
||||
static_cast<int64_t>(
|
||||
Torch::getScalarTypeForType(cTy.getDtype()))));
|
||||
c = rewriter.create<Torch::AtenQuantizePerTensorOp>(
|
||||
binder.getLoc(), cTy, c, cScale, cZp, dtyVal);
|
||||
rewriter.replaceOpWithNewOp<Torch::AtenIntReprOp>(binder.op, resultType,
|
||||
c);
|
||||
return success();
|
||||
});
|
||||
patterns.onOp("Reciprocal", 1,
|
||||
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
|
||||
Torch::ValueTensorType resultType;
|
||||
|
|
|
@ -4,7 +4,6 @@
|
|||
// level constants. This is a pragmatic choice which lets us have a lot
|
||||
// of tests in this file, whereas the others tend to be more bespoke.
|
||||
|
||||
|
||||
// CHECK-LABEL: @test_quantizelinear_si8
|
||||
func.func @test_quantizelinear_si8(%arg0: !torch.vtensor<[6],f32>, %arg1: !torch.vtensor<[],f32>, %arg2: !torch.vtensor<[],si8>) -> !torch.vtensor<[6],si8> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 19 : si64} {
|
||||
%0 = torch.operator "onnx.QuantizeLinear"(%arg0, %arg1, %arg2) : (!torch.vtensor<[6],f32>, !torch.vtensor<[],f32>, !torch.vtensor<[],si8>) -> !torch.vtensor<[6],si8>
|
||||
|
@ -48,6 +47,70 @@ func.func @test_quantizelinear_i32(%arg0: !torch.vtensor<[6],f32>, %arg1: !torch
|
|||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: @test_qlinearmatmul_2D
|
||||
func.func @test_qlinearmatmul_2D(%arg0: !torch.vtensor<[2,4],ui8>, %arg1: !torch.vtensor<[1],f32>, %arg2: !torch.vtensor<[1],ui8>, %arg3: !torch.vtensor<[4,3],ui8>, %arg4: !torch.vtensor<[1],f32>, %arg5: !torch.vtensor<[1],ui8>, %arg6: !torch.vtensor<[1],f32>, %arg7: !torch.vtensor<[1],ui8>) -> !torch.vtensor<[2,3],ui8> attributes {torch.onnx_meta.ir_version = 5 : si64, torch.onnx_meta.opset_version = 10 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||
%0 = torch.operator "onnx.QLinearMatMul"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7) : (!torch.vtensor<[2,4],ui8>, !torch.vtensor<[1],f32>, !torch.vtensor<[1],ui8>, !torch.vtensor<[4,3],ui8>, !torch.vtensor<[1],f32>, !torch.vtensor<[1],ui8>, !torch.vtensor<[1],f32>, !torch.vtensor<[1],ui8>) -> !torch.vtensor<[2,3],ui8>
|
||||
// CHECK-DAG: %[[EMPTY:.+]] = torch.prim.ListConstruct : () -> !torch.list<int>
|
||||
// CHECK-DAG: %[[RESH0:.+]] = torch.aten.reshape %arg2, %[[EMPTY]] : !torch.vtensor<[1],ui8>, !torch.list<int> -> !torch.vtensor<[],ui8>
|
||||
// CHECK-DAG: %[[RESH1:.+]] = torch.aten.reshape %arg5, %[[EMPTY]] : !torch.vtensor<[1],ui8>, !torch.list<int> -> !torch.vtensor<[],ui8>
|
||||
// CHECK-DAG: %[[RESH2:.+]] = torch.aten.reshape %arg7, %[[EMPTY]] : !torch.vtensor<[1],ui8>, !torch.list<int> -> !torch.vtensor<[],ui8>
|
||||
// CHECK-DAG: %[[RESH3:.+]] = torch.aten.reshape %arg1, %[[EMPTY]] : !torch.vtensor<[1],f32>, !torch.list<int> -> !torch.vtensor<[],f32>
|
||||
// CHECK-DAG: %[[RESH4:.+]] = torch.aten.reshape %arg4, %[[EMPTY]] : !torch.vtensor<[1],f32>, !torch.list<int> -> !torch.vtensor<[],f32>
|
||||
// CHECK-DAG: %[[RESH5:.+]] = torch.aten.reshape %arg6, %[[EMPTY]] : !torch.vtensor<[1],f32>, !torch.list<int> -> !torch.vtensor<[],f32>
|
||||
// CHECK-DAG: %[[AZP:.+]] = torch.aten.item %[[RESH0]] : !torch.vtensor<[],ui8> -> !torch.int
|
||||
// CHECK-DAG: %[[BZP:.+]] = torch.aten.item %[[RESH1]] : !torch.vtensor<[],ui8> -> !torch.int
|
||||
// CHECK-DAG: %[[CZP:.+]] = torch.aten.item %[[RESH2]] : !torch.vtensor<[],ui8> -> !torch.int
|
||||
// CHECK-DAG: %[[ASCALE:.+]] = torch.aten.item %[[RESH3]] : !torch.vtensor<[],f32> -> !torch.float
|
||||
// CHECK-DAG: %[[BSCALE:.+]] = torch.aten.item %[[RESH4]] : !torch.vtensor<[],f32> -> !torch.float
|
||||
// CHECK-DAG: %[[CCSCALE:.+]] = torch.aten.item %[[RESH5]] : !torch.vtensor<[],f32> -> !torch.float
|
||||
// CHECK-DAG: %[[LHS:.+]] = torch.aten._make_per_tensor_quantized_tensor %arg0, %[[ASCALE]], %[[AZP]] : !torch.vtensor<[2,4],ui8>, !torch.float, !torch.int -> !torch.vtensor<[2,4],!torch.quint8>
|
||||
// CHECK-DAG: %[[RHS:.+]] = torch.aten._make_per_tensor_quantized_tensor %arg3, %[[BSCALE]], %[[BZP]] : !torch.vtensor<[4,3],ui8>, !torch.float, !torch.int -> !torch.vtensor<[4,3],!torch.quint8>
|
||||
// CHECK: %[[MM:.+]] = torch.aten.mm %[[LHS]], %[[RHS]] : !torch.vtensor<[2,4],!torch.quint8>, !torch.vtensor<[4,3],!torch.quint8> -> !torch.vtensor<[2,3],si32>
|
||||
// CHECK: %[[CSCALE:.+]] = torch.aten.mul.float %[[ASCALE]], %[[BSCALE]] : !torch.float, !torch.float -> !torch.float
|
||||
// CHECK: %[[ZERO:.+]] = torch.constant.int 0
|
||||
// CHECK: %[[QC:.+]] = torch.aten._make_per_tensor_quantized_tensor %[[MM]], %[[CSCALE]], %[[ZERO]] : !torch.vtensor<[2,3],si32>, !torch.float, !torch.int -> !torch.vtensor<[2,3],!torch.qint32>
|
||||
// CHECK: %[[FC:.+]] = torch.aten.dequantize.self %[[QC]] : !torch.vtensor<[2,3],!torch.qint32> -> !torch.vtensor<[2,3],f32>
|
||||
// CHECK: %[[DTY:.+]] = torch.constant.int 13
|
||||
// CHECK: %[[QO:.+]] = torch.aten.quantize_per_tensor %[[FC]], %[[CCSCALE]], %[[CZP]], %[[DTY]] : !torch.vtensor<[2,3],f32>, !torch.float, !torch.int, !torch.int -> !torch.vtensor<[2,3],!torch.quint8>
|
||||
// CHECK: %[[OUT:.+]] = torch.aten.int_repr %[[QO]] : !torch.vtensor<[2,3],!torch.quint8> -> !torch.vtensor<[2,3],ui8>
|
||||
// CHECK: return %[[OUT]]
|
||||
return %0 : !torch.vtensor<[2,3],ui8>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: @test_qlinearmatmul_3D
|
||||
func.func @test_qlinearmatmul_3D(%arg0: !torch.vtensor<[2,2,4],ui8>, %arg1: !torch.vtensor<[1],f32>, %arg2: !torch.vtensor<[1],ui8>, %arg3: !torch.vtensor<[2,4,3],ui8>, %arg4: !torch.vtensor<[1],f32>, %arg5: !torch.vtensor<[1],ui8>, %arg6: !torch.vtensor<[1],f32>, %arg7: !torch.vtensor<[1],ui8>) -> !torch.vtensor<[2,2,3],ui8> attributes {torch.onnx_meta.ir_version = 5 : si64, torch.onnx_meta.opset_version = 10 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||
%0 = torch.operator "onnx.QLinearMatMul"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7) : (!torch.vtensor<[2,2,4],ui8>, !torch.vtensor<[1],f32>, !torch.vtensor<[1],ui8>, !torch.vtensor<[2,4,3],ui8>, !torch.vtensor<[1],f32>, !torch.vtensor<[1],ui8>, !torch.vtensor<[1],f32>, !torch.vtensor<[1],ui8>) -> !torch.vtensor<[2,2,3],ui8>
|
||||
// CHECK-DAG: %[[EMPTY:.+]] = torch.prim.ListConstruct : () -> !torch.list<int>
|
||||
// CHECK-DAG: %[[RESH0:.+]] = torch.aten.reshape %arg2, %[[EMPTY]] : !torch.vtensor<[1],ui8>, !torch.list<int> -> !torch.vtensor<[],ui8>
|
||||
// CHECK-DAG: %[[RESH1:.+]] = torch.aten.reshape %arg5, %[[EMPTY]] : !torch.vtensor<[1],ui8>, !torch.list<int> -> !torch.vtensor<[],ui8>
|
||||
// CHECK-DAG: %[[RESH2:.+]] = torch.aten.reshape %arg7, %[[EMPTY]] : !torch.vtensor<[1],ui8>, !torch.list<int> -> !torch.vtensor<[],ui8>
|
||||
// CHECK-DAG: %[[RESH3:.+]] = torch.aten.reshape %arg1, %[[EMPTY]] : !torch.vtensor<[1],f32>, !torch.list<int> -> !torch.vtensor<[],f32>
|
||||
// CHECK-DAG: %[[RESH4:.+]] = torch.aten.reshape %arg4, %[[EMPTY]] : !torch.vtensor<[1],f32>, !torch.list<int> -> !torch.vtensor<[],f32>
|
||||
// CHECK-DAG: %[[RESH5:.+]] = torch.aten.reshape %arg6, %[[EMPTY]] : !torch.vtensor<[1],f32>, !torch.list<int> -> !torch.vtensor<[],f32>
|
||||
// CHECK-DAG: %[[AZP:.+]] = torch.aten.item %[[RESH0]] : !torch.vtensor<[],ui8> -> !torch.int
|
||||
// CHECK-DAG: %[[BZP:.+]] = torch.aten.item %[[RESH1]] : !torch.vtensor<[],ui8> -> !torch.int
|
||||
// CHECK-DAG: %[[CZP:.+]] = torch.aten.item %[[RESH2]] : !torch.vtensor<[],ui8> -> !torch.int
|
||||
// CHECK-DAG: %[[ASCALE:.+]] = torch.aten.item %[[RESH3]] : !torch.vtensor<[],f32> -> !torch.float
|
||||
// CHECK-DAG: %[[BSCALE:.+]] = torch.aten.item %[[RESH4]] : !torch.vtensor<[],f32> -> !torch.float
|
||||
// CHECK-DAG: %[[CCSCALE:.+]] = torch.aten.item %[[RESH5]] : !torch.vtensor<[],f32> -> !torch.float
|
||||
// CHECK-DAG: %[[LHS:.+]] = torch.aten._make_per_tensor_quantized_tensor %arg0, %[[ASCALE]], %[[AZP]] : !torch.vtensor<[2,2,4],ui8>, !torch.float, !torch.int -> !torch.vtensor<[2,2,4],!torch.quint8>
|
||||
// CHECK-DAG: %[[RHS:.+]] = torch.aten._make_per_tensor_quantized_tensor %arg3, %[[BSCALE]], %[[BZP]] : !torch.vtensor<[2,4,3],ui8>, !torch.float, !torch.int -> !torch.vtensor<[2,4,3],!torch.quint8>
|
||||
// CHECK: %[[MM:.+]] = torch.aten.bmm %[[LHS]], %[[RHS]] : !torch.vtensor<[2,2,4],!torch.quint8>, !torch.vtensor<[2,4,3],!torch.quint8> -> !torch.vtensor<[2,2,3],si32>
|
||||
// CHECK: %[[CSCALE:.+]] = torch.aten.mul.float %[[ASCALE]], %[[BSCALE]] : !torch.float, !torch.float -> !torch.float
|
||||
// CHECK: %[[ZERO:.+]] = torch.constant.int 0
|
||||
// CHECK: %[[QC:.+]] = torch.aten._make_per_tensor_quantized_tensor %[[MM]], %[[CSCALE]], %[[ZERO]] : !torch.vtensor<[2,2,3],si32>, !torch.float, !torch.int -> !torch.vtensor<[2,2,3],!torch.qint32>
|
||||
// CHECK: %[[FC:.+]] = torch.aten.dequantize.self %[[QC]] : !torch.vtensor<[2,2,3],!torch.qint32> -> !torch.vtensor<[2,2,3],f32>
|
||||
// CHECK: %[[DTY:.+]] = torch.constant.int 13
|
||||
// CHECK: %[[QO:.+]] = torch.aten.quantize_per_tensor %[[FC]], %[[CCSCALE]], %[[CZP]], %[[DTY]] : !torch.vtensor<[2,2,3],f32>, !torch.float, !torch.int, !torch.int -> !torch.vtensor<[2,2,3],!torch.quint8>
|
||||
// CHECK: %[[OUT:.+]] = torch.aten.int_repr %[[QO]] : !torch.vtensor<[2,2,3],!torch.quint8> -> !torch.vtensor<[2,2,3],ui8>
|
||||
// CHECK: return %[[OUT]]
|
||||
return %0 : !torch.vtensor<[2,2,3],ui8>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @test_reciprocal
|
||||
func.func @test_reciprocal(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||
// CHECK: torch.aten.reciprocal %arg0 : !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],f32>
|
||||
|
|
Loading…
Reference in New Issue