mirror of https://github.com/llvm/torch-mlir
[TorchToLinalg] Adds Support for Remaining Quantized Matmul Cases (#3167)
The new cases added for quantized matmuls are: 1. vec-vec, 2. vec-mat, 3. mat-vec. Each of these is now lowered to expand(s), quantized_matmul, and collapse.
parent a0232e9ebd
commit 7a1ad0d7c0
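For context on the lowering strategy described in the commit message: each of the three new cases involves a 1-D operand, and torch.matmul's rank rules let each be rewritten as an ordinary 2-D matmul by temporarily unsqueezing the vector(s) into matrices and squeezing the padded dimension back out of the result. A minimal eager-mode PyTorch sketch of that equivalence (illustrative shapes only; this mirrors the expand/quantized_matmul/collapse pattern but is not the lowering itself):

import torch

m, n, k = 3, 4, 5
A = torch.randn(m, k)   # matrix lhs
B = torch.randn(k, n)   # matrix rhs
v = torch.randn(k)      # vector operand

# mat-vec: (m, k) @ (k,) -> (m,); pad rhs to (k, 1), then drop the unit dim
assert torch.allclose(A @ v, (A @ v.unsqueeze(1)).squeeze(1))
# vec-mat: (k,) @ (k, n) -> (n,); pad lhs to (1, k), then drop the unit dim
assert torch.allclose(v @ B, (v.unsqueeze(0) @ B).squeeze(0))
# vec-vec: (k,) @ (k,) -> rank-0 scalar; pad both sides, collapse (1, 1) away
assert torch.allclose(v @ v, (v.unsqueeze(0) @ v.unsqueeze(1)).reshape(()))

In the pass itself, the same padding is done with tensor.expand_shape on the builtin-tensor operands, the 2-D product is computed by linalg.quantized_matmul, and tensor.collapse_shape removes the unit dimension(s), with the reassociation indices cleared in the rank-0 vec-vec case, as the diff below shows.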
@@ -327,11 +327,12 @@ public:
          op, "unsupported: aten.matmul with different input element types");
    }

    Type newResultType = getTypeConverter()->convertType(op.getType());
    auto resultType = cast<RankedTensorType>(newResultType);
    Type elementType = resultType.getElementType();

    if (lhsZeroPoint) {
      if (lhsRank < 2 || rhsRank < 2) {
        return rewriter.notifyMatchFailure(
            op, "unsupported: quantized aten.mm with vector");
      }
      // get each zero point ready to pass to a quantized_matmul
      lhsZeroPoint = typeConverter->materializeTargetConversion(
          rewriter, loc,
          getTypeConverter()->convertType(lhsZeroPoint.getType()),
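The next hunk begins with the existing signShift(...) calls, which rebase unsigned quantized operands into signed range before the linalg op; only the difference between a stored value and its zero point matters, so shifting both by the same constant is value-preserving. A small sketch of that identity, assuming the shift amounts to subtracting 2^(bits-1) from both the stored values and the zero point (the helper's exact mechanics are not shown in this diff):

import torch

bits = 8
shift = 2 ** (bits - 1)  # 128 for 8-bit operands

q = torch.randint(0, 256, (6,), dtype=torch.int32)  # uint8 values, held widened
zp = 200                                            # a hypothetical unsigned zero point

q_signed = q - shift          # now within the int8 range [-128, 127]
zp_signed = zp - shift
# The dequantized quantity (q - zp) is unchanged by the common shift.
assert torch.equal(q - zp, q_signed - zp_signed)
assert int(q_signed.min()) >= -128 and int(q_signed.max()) <= 127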
@@ -351,11 +352,61 @@ public:
        signShift(rewriter, loc, lhs, lhsZeroPoint, isUnsigned, numBits);
        numBits = rhsType.getElementType().cast<mlir::IntegerType>().getWidth();
        signShift(rewriter, loc, rhs, rhsZeroPoint, isUnsignedR, numBits);
      }

      Type newResultType = getTypeConverter()->convertType(op.getType());
      auto resultType = cast<RankedTensorType>(newResultType);
      Type elementType = resultType.getElementType();
      // for quantized vec-vec, vec-mat, and mat-vec cases, lower to
      // expand/collapse + quantized_matmul
      bool lhsVec = (lhsRank == 1 && rhsRank <= 2);
      bool rhsVec = (lhsRank <= 2 && rhsRank == 1);

      if (lhsVec || rhsVec) {
        SmallVector<ReassociationIndices> reassociation(1);
        reassociation[0].push_back(0);
        reassociation[0].push_back(1);

        if (lhsVec) {
          // unsqueeze lhs to a matrix
          int64_t lhsDim = lhsType.getShape()[0];
          auto lhsUnsqueezeType = RankedTensorType::get(
              ArrayRef<int64_t>{1, lhsDim}, lhsType.getElementType());
          lhs = rewriter.create<tensor::ExpandShapeOp>(loc, lhsUnsqueezeType,
                                                       lhs, reassociation);
        }
        if (rhsVec) {
          // unsqueeze rhs to a matrix
          int64_t rhsDim = rhsType.getShape()[0];
          auto rhsUnsqueezeType = RankedTensorType::get(
              ArrayRef<int64_t>{rhsDim, 1}, rhsType.getElementType());
          rhs = rewriter.create<tensor::ExpandShapeOp>(loc, rhsUnsqueezeType,
                                                       rhs, reassociation);
        }
        // get quantized_matmul and squeeze result
        Value lhsDim0 = getDimOp(rewriter, loc, lhs, 0);
        Value lhsDim1 = getDimOp(rewriter, loc, lhs, 1);
        Value rhsDim0 = getDimOp(rewriter, loc, rhs, 0);
        Value rhsDim1 = getDimOp(rewriter, loc, rhs, 1);
        checkDimEqualHelper(rewriter, loc, lhsDim1, rhsDim0);

        Value zeroTensor = createZeroInitTensor(
            rewriter, loc, ValueRange{lhsDim0, rhsDim1}, elementType);
        Value matmul = rewriter
                           .create<linalg::QuantizedMatmulOp>(
                               loc, zeroTensor.getType(),
                               ValueRange{lhs, rhs, lhsZeroPoint, rhsZeroPoint},
                               zeroTensor)
                           .getResult(0);
        int64_t resultRank = resultType.getRank();
        if (resultRank == 0) {
          // in vec-vec case, need to collapse result to a scalar
          reassociation.clear();
        }
        matmul = rewriter.create<tensor::CollapseShapeOp>(
            loc, resultType, matmul, reassociation);
        rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, matmul);
        return success();
      }
      // the remaining quantized cases (Mat-Mat and broadcast -> BMM) are
      // covered in the relevant section below
    }

    // The different cases of the torch.matmul op are described here:
    // https://pytorch.org/docs/stable/generated/torch.matmul.html
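The linalg.quantized_matmul created above subtracts each operand's zero point before the multiply and accumulates into its init tensor, which is why the code builds a zero-filled accumulator (createZeroInitTensor) in the converted result element type. A small reference sketch of that accumulation, written with an explicit reduction so it runs for integer dtypes; this is a sanity check of the op's documented semantics, not the lowering itself, and the zero-point values are taken from the tests below:

import torch

def quantized_matmul_ref(lhs, rhs, lhs_zp, rhs_zp):
    # out[m, n] = sum_k (lhs[m, k] - lhs_zp) * (rhs[k, n] - rhs_zp),
    # with operands widened to the accumulator type (int32 here).
    lhs32 = lhs.to(torch.int32) - lhs_zp
    rhs32 = rhs.to(torch.int32) - rhs_zp
    return (lhs32.unsqueeze(2) * rhs32.unsqueeze(0)).sum(dim=1)

lhs = torch.randint(-128, 128, (1, 5), dtype=torch.int8)  # vec-mat lhs after unsqueeze
rhs = torch.randint(-128, 128, (5, 4), dtype=torch.int8)
out = quantized_matmul_ref(lhs, rhs, lhs_zp=-25, rhs_zp=18)
print(out.shape)  # torch.Size([1, 4]); collapse_shape then yields the 1-D result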
@@ -323,6 +323,8 @@ TORCHDYNAMO_XFAIL_SET = {
    "AtenMatmulQMixedSigni8Transpose_basic",
    "AtenMatmulQMixedSigni8_basic",
    "AtenMatmulQint8MV_basic",
    "AtenMatmulQint8VV_basic",
    "AtenMatmulQint8VM_basic",
    "AtenMatmulQint8_basic",
    "Conv2dQInt8Module_basic",
@@ -1974,6 +1976,8 @@ ONNX_XFAIL_SET = {
    "AtenMatmulQMixedSigni8Transpose_basic",
    "AtenMatmulQMixedSigni8_basic",
    "AtenMatmulQint8MV_basic",
    "AtenMatmulQint8VV_basic",
    "AtenMatmulQint8VM_basic",
    "AtenMatmulQint8_basic",
    "AtenRealView128Module_basic",
    "AtenRealView64Module_basic",
@@ -14,7 +14,6 @@ COMMON_TORCH_MLIR_LOWERING_XFAILS = {
    "QuantizedMLP_basic",
    "QuantizedSingleLayer_basic",
    "QuantizedBatchedInputSingleLayer_basic",
    "AtenMatmulQint8MV_basic",
    "ReduceMaxAlongDimUnsignedInt_basic",
    "ReduceMinAlongDimUnsignedInt_basic",
    "ElementwiseToDtypeI64ToUI8Module_basic",
@@ -344,6 +344,56 @@ def AtenMmQMixedSigni8_basic(module, tu: TestUtils):

# ==============================================================================

class AtenMatmulQint8VM(torch.nn.Module):

    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1], torch.int8, True),
        ([-1, -1], torch.int8, True),
    ])
    def forward(self, x, y):
        qx = torch._make_per_tensor_quantized_tensor(x, 0.0215, -25)
        qx = torch.dequantize(qx)
        qy = torch._make_per_tensor_quantized_tensor(y, 0.0176, 18)
        qy = torch.dequantize(qy)
        qz = torch.matmul(qx, qy)
        return qz

@register_test_case(module_factory=lambda: AtenMatmulQint8VM())
def AtenMatmulQint8VM_basic(module, tu: TestUtils):
    module.forward(tu.randint(9, low=-128, high=127).to(torch.int8),
                   tu.randint(9, 4, low=-128, high=127).to(torch.int8))
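The new test modules follow the suite's existing pattern: the int8 inputs are wrapped as per-tensor quantized tensors and immediately dequantized, so eager execution computes a plain float matmul while the quantization parameters remain visible for the quantized lowering path to pick up. A quick numeric check of what that wrapping computes under the per-tensor affine scheme (scale and zero point taken from the test above):

import torch

x = torch.randint(-128, 128, (9,), dtype=torch.int8)
scale, zero_point = 0.0215, -25

qx = torch._make_per_tensor_quantized_tensor(x, scale, zero_point)
dq = torch.dequantize(qx)

# dequantized value = (stored int8 value - zero_point) * scale
expected = (x.to(torch.float32) - zero_point) * scale
assert torch.allclose(dq, expected)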
# ==============================================================================

class AtenMatmulQint8VV(torch.nn.Module):

    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1], torch.int8, True),
        ([-1], torch.int8, True),
    ])
    def forward(self, x, y):
        qx = torch._make_per_tensor_quantized_tensor(x, 0.0215, -25)
        qx = torch.dequantize(qx)
        qy = torch._make_per_tensor_quantized_tensor(y, 0.0176, 18)
        qy = torch.dequantize(qy)
        qz = torch.matmul(qx, qy)
        return qz

@register_test_case(module_factory=lambda: AtenMatmulQint8VV())
def AtenMatmulQint8VV_basic(module, tu: TestUtils):
    module.forward(tu.randint(9, low=-128, high=127).to(torch.int8),
                   tu.randint(9, low=-128, high=127).to(torch.int8))

# ==============================================================================

class AtenMatmulQint8MV(torch.nn.Module):

    def __init__(self):
@@ -352,8 +402,8 @@ class AtenMatmulQint8MV(torch.nn.Module):
    @export
    @annotate_args([
        None,
        ([3, 4], torch.int8, True),
        ([4], torch.int8, True),
        ([-1, -1], torch.int8, True),
        ([-1], torch.int8, True),
    ])
    def forward(self, x, y):
        qx = torch._make_per_tensor_quantized_tensor(x, 0.0215, -25)