Update quantized matmul tests to DQ/Q format supported by fx_importer (#3815)

Continuation of https://github.com/llvm/torch-mlir/pull/3809 for the
matmul tests.
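
For reference: the old versions of these tests built real quantized tensors in eager mode via torch._make_per_tensor_quantized_tensor followed by torch.dequantize, which the fx_importer config could not import (hence the FX_IMPORTER_XFAIL_SET entries removed below). The new versions express the same dequantization with the decomposed fake-quantization op, which torch.export traces as ordinary graph nodes. A minimal sketch of the two forms, with illustrative scale/zero-point values; the old_style/new_style helpers are not part of this patch:

import torch
import torch.ao.quantization.fx._decomposed  # registers the quantized_decomposed op library

def old_style(x: torch.Tensor) -> torch.Tensor:
    # Eager-only: wraps the int8 data in a real quantized tensor, then dequantizes.
    qx = torch._make_per_tensor_quantized_tensor(x, 0.0215, -25)
    return torch.dequantize(qx)

def new_style(x: torch.Tensor) -> torch.Tensor:
    # Traceable DQ form: plain int8 tensor in, float tensor out.
    # Arguments: input, scale, zero_point, quant_min, quant_max, dtype.
    return torch.ops.quantized_decomposed.dequantize_per_tensor.default(
        x, 0.0215, -25, -128, 127, torch.int8
    )

The DQ form also makes quant_min/quant_max and the storage dtype explicit, which is how the mixed signed/unsigned tests (e.g. AtenMmQMixedSigni8) state their operand ranges.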
Felix Schneider 2024-10-24 21:59:58 +02:00 committed by GitHub
parent 1259e8a00a
commit 76209db5a5
2 changed files with 75 additions and 64 deletions


@@ -394,15 +394,6 @@ FX_IMPORTER_XFAIL_SET = {
     "AtenIntBoolOpModule_basic",
     "AtenIntMM_basic",
     "AtenItemFpOpModule_basic",
-    "AtenMatmulQMixedSigni8Transpose_basic",
-    "AtenMatmulQMixedSigni8_basic",
-    "AtenMatmulQint8MV_basic",
-    "AtenMatmulQint8_basic",
-    "AtenMatmulQint8VM_basic",
-    "AtenMatmulQint8VV_basic",
-    "AtenMmQMixedSigni8_basic",
-    "AtenMmQint8_basic",
-    "AtenMmQuint8_basic",
     "QuantizedReluInt32_basic",
     "QuantizedReluInt8_basic",
     "QuantizedReluUint8_basic",


@@ -337,6 +337,8 @@ def AtenMmIntTypes_basic(module, tu: TestUtils):
 # ==============================================================================
+# For DQ-Q fake quantization ops
+import torch.ao.quantization.fx._decomposed
 class AtenMmQint8(torch.nn.Module):
@@ -352,12 +354,14 @@ class AtenMmQint8(torch.nn.Module):
         ]
     )
     def forward(self, x, y):
-        qx = torch._make_per_tensor_quantized_tensor(x, 0.0215, -25)
-        qx = torch.dequantize(qx)
-        qy = torch._make_per_tensor_quantized_tensor(y, 0.0176, 18)
-        qy = torch.dequantize(qy)
-        qz = torch.mm(qx, qy)
-        return qz
+        x = torch.torch.ops.quantized_decomposed.dequantize_per_tensor.default(
+            x, 0.0215, -25, -128, 127, torch.int8
+        )
+        y = torch.torch.ops.quantized_decomposed.dequantize_per_tensor.default(
+            y, 0.0176, 18, -128, 127, torch.int8
+        )
+        z = torch.mm(x, y)
+        return z

 @register_test_case(module_factory=lambda: AtenMmQint8())
@@ -384,12 +388,14 @@ class AtenMmQuint8(torch.nn.Module):
         ]
     )
     def forward(self, x, y):
-        qx = torch._make_per_tensor_quantized_tensor(x, 0.199, 65)
-        qx = torch.dequantize(qx)
-        qy = torch._make_per_tensor_quantized_tensor(y, 0.0215, 160)
-        qy = torch.dequantize(qy)
-        qz = torch.mm(qx, qy)
-        return qz
+        x = torch.torch.ops.quantized_decomposed.dequantize_per_tensor.default(
+            x, 0.199, 65, 0, 255, torch.uint8
+        )
+        y = torch.torch.ops.quantized_decomposed.dequantize_per_tensor.default(
+            y, 0.0215, 160, 0, 255, torch.uint8
+        )
+        z = torch.mm(x, y)
+        return z

 @register_test_case(module_factory=lambda: AtenMmQuint8())
@@ -416,12 +422,14 @@ class AtenMmQMixedSigni8(torch.nn.Module):
         ]
     )
     def forward(self, x, y):
-        qx = torch._make_per_tensor_quantized_tensor(x, 0.03, -66)
-        qx = torch.dequantize(qx)
-        qy = torch._make_per_tensor_quantized_tensor(y, 0.025, 160)
-        qy = torch.dequantize(qy)
-        qz = torch.mm(qx, qy)
-        return qz
+        x = torch.torch.ops.quantized_decomposed.dequantize_per_tensor.default(
+            x, 0.03, -66, -128, 127, torch.int8
+        )
+        y = torch.torch.ops.quantized_decomposed.dequantize_per_tensor.default(
+            y, 0.025, 160, 0, 255, torch.uint8
+        )
+        z = torch.mm(x, y)
+        return z

 @register_test_case(module_factory=lambda: AtenMmQMixedSigni8())
@@ -475,12 +483,14 @@ class AtenMatmulQint8VM(torch.nn.Module):
         ]
     )
     def forward(self, x, y):
-        qx = torch._make_per_tensor_quantized_tensor(x, 0.0215, -25)
-        qx = torch.dequantize(qx)
-        qy = torch._make_per_tensor_quantized_tensor(y, 0.0176, 18)
-        qy = torch.dequantize(qy)
-        qz = torch.matmul(qx, qy)
-        return qz
+        x = torch.torch.ops.quantized_decomposed.dequantize_per_tensor.default(
+            x, 0.0215, -25, -128, 127, torch.int8
+        )
+        y = torch.torch.ops.quantized_decomposed.dequantize_per_tensor.default(
+            y, 0.0176, 18, -128, 127, torch.int8
+        )
+        z = torch.matmul(x, y)
+        return z

 @register_test_case(module_factory=lambda: AtenMatmulQint8VM())
@@ -505,12 +515,14 @@ class AtenMatmulQint8VV(torch.nn.Module):
         ]
    )
     def forward(self, x, y):
-        qx = torch._make_per_tensor_quantized_tensor(x, 0.0215, -25)
-        qx = torch.dequantize(qx)
-        qy = torch._make_per_tensor_quantized_tensor(y, 0.0176, 18)
-        qy = torch.dequantize(qy)
-        qz = torch.matmul(qx, qy)
-        return qz
+        x = torch.torch.ops.quantized_decomposed.dequantize_per_tensor.default(
+            x, 0.0215, -25, -128, 127, torch.int8
+        )
+        y = torch.torch.ops.quantized_decomposed.dequantize_per_tensor.default(
+            y, 0.0176, 18, -128, 127, torch.int8
+        )
+        z = torch.matmul(x, y)
+        return z

 @register_test_case(module_factory=lambda: AtenMatmulQint8VV())
@@ -535,12 +547,14 @@ class AtenMatmulQint8MV(torch.nn.Module):
         ]
     )
     def forward(self, x, y):
-        qx = torch._make_per_tensor_quantized_tensor(x, 0.0215, -25)
-        qx = torch.dequantize(qx)
-        qy = torch._make_per_tensor_quantized_tensor(y, 0.0176, 18)
-        qy = torch.dequantize(qy)
-        qz = torch.matmul(qx, qy)
-        return qz
+        x = torch.torch.ops.quantized_decomposed.dequantize_per_tensor.default(
+            x, 0.0215, -25, -128, 127, torch.int8
+        )
+        y = torch.torch.ops.quantized_decomposed.dequantize_per_tensor.default(
+            y, 0.0176, 18, -128, 127, torch.int8
+        )
+        z = torch.matmul(x, y)
+        return z

 @register_test_case(module_factory=lambda: AtenMatmulQint8MV())
@@ -565,12 +579,14 @@ class AtenMatmulQint8(torch.nn.Module):
         ]
     )
     def forward(self, x, y):
-        qx = torch._make_per_tensor_quantized_tensor(x, 0.0215, -25)
-        qx = torch.dequantize(qx)
-        qy = torch._make_per_tensor_quantized_tensor(y, 0.0176, 18)
-        qy = torch.dequantize(qy)
-        qz = torch.matmul(qx, qy)
-        return qz
+        x = torch.torch.ops.quantized_decomposed.dequantize_per_tensor.default(
+            x, 0.0215, -25, -128, 127, torch.int8
+        )
+        y = torch.torch.ops.quantized_decomposed.dequantize_per_tensor.default(
+            y, 0.0176, 18, -128, 127, torch.int8
+        )
+        z = torch.matmul(x, y)
+        return z

 @register_test_case(module_factory=lambda: AtenMatmulQint8())
@@ -597,12 +613,14 @@ class AtenMatmulQMixedSigni8(torch.nn.Module):
         ]
     )
     def forward(self, x, y):
-        qx = torch._make_per_tensor_quantized_tensor(x, 0.03, -66)
-        qx = torch.dequantize(qx)
-        qy = torch._make_per_tensor_quantized_tensor(y, 0.025, 160)
-        qy = torch.dequantize(qy)
-        qz = torch.matmul(qx, qy)
-        return qz
+        x = torch.torch.ops.quantized_decomposed.dequantize_per_tensor.default(
+            x, 0.03, -66, -128, 127, torch.int8
+        )
+        y = torch.torch.ops.quantized_decomposed.dequantize_per_tensor.default(
+            y, 0.025, 160, 0, 255, torch.uint8
+        )
+        z = torch.matmul(x, y)
+        return z

 @register_test_case(module_factory=lambda: AtenMatmulQMixedSigni8())
@@ -629,13 +647,15 @@ class AtenMatmulQMixedSigni8Transpose(torch.nn.Module):
         ]
     )
     def forward(self, x, y):
-        qx = torch._make_per_tensor_quantized_tensor(x, 0.03, -66)
-        qx = torch.dequantize(qx)
-        qy = torch._make_per_tensor_quantized_tensor(y, 0.025, 160)
-        qy = torch.dequantize(qy)
-        qy = torch.transpose(qy, 1, 2)
-        qz = torch.matmul(qx, qy)
-        return qz
+        x = torch.torch.ops.quantized_decomposed.dequantize_per_tensor.default(
+            x, 0.03, -66, -128, 127, torch.int8
+        )
+        y = torch.torch.ops.quantized_decomposed.dequantize_per_tensor.default(
+            y, 0.025, 160, 0, 255, torch.uint8
+        )
+        y = torch.transpose(y, 1, 2)
+        z = torch.matmul(x, y)
+        return z

 @register_test_case(module_factory=lambda: AtenMatmulQMixedSigni8Transpose())
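
The updated tests can be exercised against the fx_importer config with the usual e2e harness, e.g. python -m e2e_testing.main --config=fx_importer -f "AtenMmQint8" run from projects/pt1 (flag names as of this commit; check e2e_testing/main.py for the current interface).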