[TorchToLinalg] Adds Support for Remaining Quantized Matmul Cases (#3167)

The new cases added for quantized matmuls are:

1. vec-vec
2. vec-mat
3. mat-vec

each of which is now lowered to expand_shape op(s), a quantized_matmul, and a
collapse_shape.
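
For intuition, here is a small NumPy sketch (not code from this patch) of the rank-expansion trick: any 1-D operand is unsqueezed to a matrix, an ordinary matrix-matrix product stands in for `linalg.quantized_matmul`, and the unit dimension(s) are collapsed back out of the result.

```python
import numpy as np

# Stand-ins for the quantized operands; zero points are omitted here and are
# handled by linalg.quantized_matmul in the real lowering.
vec = np.arange(4, dtype=np.int32)                 # shape (4,)
mat = np.arange(12, dtype=np.int32).reshape(4, 3)  # shape (4, 3)

# vec-mat: expand_shape (4,) -> (1, 4), matmul, collapse_shape (1, 3) -> (3,)
vm = (vec.reshape(1, 4) @ mat).reshape(3)
assert np.array_equal(vm, vec @ mat)

# mat-vec: expand_shape (4,) -> (4, 1), matmul, collapse_shape (3, 1) -> (3,)
mv = (mat.T @ vec.reshape(4, 1)).reshape(3)
assert np.array_equal(mv, mat.T @ vec)

# vec-vec: expand both sides; the (1, 1) result collapses to a scalar
vv = (vec.reshape(1, 4) @ vec.reshape(4, 1)).reshape(())
assert vv == vec @ vec
```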
zjgarvey 2024-04-16 11:28:28 -05:00 committed by GitHub
parent a0232e9ebd
commit 7a1ad0d7c0
4 changed files with 115 additions and 11 deletions


@@ -327,11 +327,12 @@ public:
          op, "unsupported: aten.matmul with different input element types");
    }
    Type newResultType = getTypeConverter()->convertType(op.getType());
    auto resultType = cast<RankedTensorType>(newResultType);
    Type elementType = resultType.getElementType();
    if (lhsZeroPoint) {
      if (lhsRank < 2 || rhsRank < 2) {
        return rewriter.notifyMatchFailure(
            op, "unsupported: quantized aten.mm with vector");
      }
      // get each zero point ready to pass to a quantized_matmul
      lhsZeroPoint = typeConverter->materializeTargetConversion(
          rewriter, loc,
          getTypeConverter()->convertType(lhsZeroPoint.getType()),
@@ -351,11 +352,61 @@ public:
      signShift(rewriter, loc, lhs, lhsZeroPoint, isUnsigned, numBits);
      numBits = rhsType.getElementType().cast<mlir::IntegerType>().getWidth();
      signShift(rewriter, loc, rhs, rhsZeroPoint, isUnsignedR, numBits);
    }
    Type newResultType = getTypeConverter()->convertType(op.getType());
    auto resultType = cast<RankedTensorType>(newResultType);
    Type elementType = resultType.getElementType();
      // for quantized vec-vec, vec-mat, and mat-vec cases, lower to
      // expand/collapse + quantized_matmul
      bool lhsVec = (lhsRank == 1 && rhsRank <= 2);
      bool rhsVec = (lhsRank <= 2 && rhsRank == 1);
      if (lhsVec || rhsVec) {
        SmallVector<ReassociationIndices> reassociation(1);
        reassociation[0].push_back(0);
        reassociation[0].push_back(1);
        if (lhsVec) {
          // unsqueeze lhs to a matrix
          int64_t lhsDim = lhsType.getShape()[0];
          auto lhsUnsqueezeType = RankedTensorType::get(
              ArrayRef<int64_t>{1, lhsDim}, lhsType.getElementType());
          lhs = rewriter.create<tensor::ExpandShapeOp>(loc, lhsUnsqueezeType,
                                                       lhs, reassociation);
        }
        if (rhsVec) {
          // unsqueeze rhs to a matrix
          int64_t rhsDim = rhsType.getShape()[0];
          auto rhsUnsqueezeType = RankedTensorType::get(
              ArrayRef<int64_t>{rhsDim, 1}, rhsType.getElementType());
          rhs = rewriter.create<tensor::ExpandShapeOp>(loc, rhsUnsqueezeType,
                                                       rhs, reassociation);
        }
        // get quantized_matmul and squeeze result
        Value lhsDim0 = getDimOp(rewriter, loc, lhs, 0);
        Value lhsDim1 = getDimOp(rewriter, loc, lhs, 1);
        Value rhsDim0 = getDimOp(rewriter, loc, rhs, 0);
        Value rhsDim1 = getDimOp(rewriter, loc, rhs, 1);
        checkDimEqualHelper(rewriter, loc, lhsDim1, rhsDim0);
        Value zeroTensor = createZeroInitTensor(
            rewriter, loc, ValueRange{lhsDim0, rhsDim1}, elementType);
        Value matmul = rewriter
                           .create<linalg::QuantizedMatmulOp>(
                               loc, zeroTensor.getType(),
                               ValueRange{lhs, rhs, lhsZeroPoint, rhsZeroPoint},
                               zeroTensor)
                           .getResult(0);
        int64_t resultRank = resultType.getRank();
        if (resultRank == 0) {
          // in vec-vec case, need to collapse result to a scalar
          reassociation.clear();
        }
        matmul = rewriter.create<tensor::CollapseShapeOp>(
            loc, resultType, matmul, reassociation);
        rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, matmul);
        return success();
      }
      // the remaining quantized cases (Mat-Mat and broadcast -> BMM) are
      // covered in the relevant section below
    }
    // The different cases of torch_matmul op is mentioned here:
    // https://pytorch.org/docs/stable/generated/torch.matmul.html
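
For reference, `linalg.quantized_matmul` subtracts the two zero points from the (widened) operands before accumulating into the init tensor. Below is a rough NumPy model of what the vec-mat lowering computes, using the shapes and zero points from the `AtenMatmulQint8VM` test further down; this is an illustration, not the converter code or the linalg implementation.

```python
import numpy as np

def quantized_matmul_ref(lhs, rhs, lhs_zp, rhs_zp, init):
    """Rough model of linalg.quantized_matmul: widen the operands, subtract
    the zero points, and accumulate an ordinary matmul into init."""
    return init + (lhs.astype(np.int32) - lhs_zp) @ (rhs.astype(np.int32) - rhs_zp)

rng = np.random.default_rng(0)
x = rng.integers(-128, 128, size=(9,), dtype=np.int8)    # quantized vector
w = rng.integers(-128, 128, size=(9, 4), dtype=np.int8)  # quantized matrix

# vec-mat: expand x to (1, 9), quantized matmul against (9, 4),
# then collapse the (1, 4) accumulator back to a rank-1 result.
init = np.zeros((1, 4), dtype=np.int32)
acc = quantized_matmul_ref(x.reshape(1, 9), w, lhs_zp=-25, rhs_zp=18, init=init)
result = acc.reshape(4)
```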


@@ -323,6 +323,8 @@ TORCHDYNAMO_XFAIL_SET = {
    "AtenMatmulQMixedSigni8Transpose_basic",
    "AtenMatmulQMixedSigni8_basic",
    "AtenMatmulQint8MV_basic",
    "AtenMatmulQint8VV_basic",
    "AtenMatmulQint8VM_basic",
    "AtenMatmulQint8_basic",
    "Conv2dQInt8Module_basic",
@@ -1974,6 +1976,8 @@ ONNX_XFAIL_SET = {
    "AtenMatmulQMixedSigni8Transpose_basic",
    "AtenMatmulQMixedSigni8_basic",
    "AtenMatmulQint8MV_basic",
    "AtenMatmulQint8VV_basic",
    "AtenMatmulQint8VM_basic",
    "AtenMatmulQint8_basic",
    "AtenRealView128Module_basic",
    "AtenRealView64Module_basic",


@@ -14,7 +14,6 @@ COMMON_TORCH_MLIR_LOWERING_XFAILS = {
    "QuantizedMLP_basic",
    "QuantizedSingleLayer_basic",
    "QuantizedBatchedInputSingleLayer_basic",
    "AtenMatmulQint8MV_basic",
    "ReduceMaxAlongDimUnsignedInt_basic",
    "ReduceMinAlongDimUnsignedInt_basic",
    "ElementwiseToDtypeI64ToUI8Module_basic",


@@ -344,6 +344,56 @@ def AtenMmQMixedSigni8_basic(module, tu: TestUtils):
# ==============================================================================

class AtenMatmulQint8VM(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1], torch.int8, True),
        ([-1, -1], torch.int8, True),
    ])
    def forward(self, x, y):
        qx = torch._make_per_tensor_quantized_tensor(x, 0.0215, -25)
        qx = torch.dequantize(qx)
        qy = torch._make_per_tensor_quantized_tensor(y, 0.0176, 18)
        qy = torch.dequantize(qy)
        qz = torch.matmul(qx, qy)
        return qz

@register_test_case(module_factory=lambda: AtenMatmulQint8VM())
def AtenMatmulQint8VM_basic(module, tu: TestUtils):
    module.forward(tu.randint(9, low=-128, high=127).to(torch.int8),
                   tu.randint(9, 4, low=-128, high=127).to(torch.int8))

# ==============================================================================

class AtenMatmulQint8VV(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1], torch.int8, True),
        ([-1], torch.int8, True),
    ])
    def forward(self, x, y):
        qx = torch._make_per_tensor_quantized_tensor(x, 0.0215, -25)
        qx = torch.dequantize(qx)
        qy = torch._make_per_tensor_quantized_tensor(y, 0.0176, 18)
        qy = torch.dequantize(qy)
        qz = torch.matmul(qx, qy)
        return qz

@register_test_case(module_factory=lambda: AtenMatmulQint8VV())
def AtenMatmulQint8VV_basic(module, tu: TestUtils):
    module.forward(tu.randint(9, low=-128, high=127).to(torch.int8),
                   tu.randint(9, low=-128, high=127).to(torch.int8))
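
As a side note (not part of the patch): these tests attach per-tensor quantization parameters to int8 data, dequantize, and call `torch.matmul`, which is the pattern torch-mlir fuses into the quantized matmul path lowered above. The arithmetic behind that is `dequantize(q) = scale * (q - zero_point)`, so the float matmul equals, up to rounding, the product of the two scales times the zero-point-adjusted integer matmul. A quick standalone check using the same scales and zero points as the tests:

```python
import torch

x = torch.randint(-128, 128, (9,), dtype=torch.int8)
y = torch.randint(-128, 128, (9, 4), dtype=torch.int8)

# dequantize(q) = scale * (q - zero_point)
dq_x = 0.0215 * (x.to(torch.float32) - (-25))
dq_y = 0.0176 * (y.to(torch.float32) - 18)

# zero-point-adjusted integer matmul, as in the quantized lowering
int_acc = (x.to(torch.int32) - (-25)) @ (y.to(torch.int32) - 18)

assert torch.allclose(dq_x @ dq_y,
                      0.0215 * 0.0176 * int_acc.to(torch.float32),
                      atol=1e-3)
```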
# ==============================================================================
class AtenMatmulQint8MV(torch.nn.Module):
    def __init__(self):
@@ -352,8 +402,8 @@ class AtenMatmulQint8MV(torch.nn.Module):
    @export
    @annotate_args([
        None,
        ([3, 4], torch.int8, True),
        ([4], torch.int8, True),
        ([-1, -1], torch.int8, True),
        ([-1], torch.int8, True),
    ])
    def forward(self, x, y):
        qx = torch._make_per_tensor_quantized_tensor(x, 0.0215, -25)