Update quantized matmul tests to DQ/Q format supported by fx_importer (#3815)

Continuation of https://github.com/llvm/torch-mlir/pull/3809 for the
matmul tests. The tests now express fake quantization with
torch.ops.quantized_decomposed.dequantize_per_tensor instead of the
torch._make_per_tensor_quantized_tensor / torch.dequantize pair, so
fx_importer can import them and they come off FX_IMPORTER_XFAIL_SET.
Felix Schneider 2024-10-24 21:59:58 +02:00 committed by GitHub
parent 1259e8a00a
commit 76209db5a5
2 changed files with 75 additions and 64 deletions
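
In the DQ form, the raw int8/uint8 input is passed through
torch.ops.quantized_decomposed.dequantize_per_tensor, which fx_importer can
trace, rather than through the private torch._make_per_tensor_quantized_tensor
API. A minimal sketch of the op's semantics (the scale/zero-point values here
are illustrative, not taken from the tests):

import torch
import torch.ao.quantization.fx._decomposed  # registers the quantized_decomposed ops

x = torch.randint(-128, 128, (3, 4), dtype=torch.int8)
scale, zero_point = 0.0215, -25

# dequantize_per_tensor(input, scale, zero_point, quant_min, quant_max, dtype)
# computes (input.float() - zero_point) * scale elementwise, returning float32.
xf = torch.ops.quantized_decomposed.dequantize_per_tensor.default(
    x, scale, zero_point, -128, 127, torch.int8
)
assert torch.allclose(xf, (x.float() - zero_point) * scale)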

projects/pt1/e2e_testing/xfail_sets.py

@@ -394,15 +394,6 @@ FX_IMPORTER_XFAIL_SET = {
     "AtenIntBoolOpModule_basic",
     "AtenIntMM_basic",
     "AtenItemFpOpModule_basic",
-    "AtenMatmulQMixedSigni8Transpose_basic",
-    "AtenMatmulQMixedSigni8_basic",
-    "AtenMatmulQint8MV_basic",
-    "AtenMatmulQint8_basic",
-    "AtenMatmulQint8VM_basic",
-    "AtenMatmulQint8VV_basic",
-    "AtenMmQMixedSigni8_basic",
-    "AtenMmQint8_basic",
-    "AtenMmQuint8_basic",
     "QuantizedReluInt32_basic",
     "QuantizedReluInt8_basic",
     "QuantizedReluUint8_basic",

projects/pt1/python/torch_mlir_e2e_test/test_suite/matmul.py

@@ -337,6 +337,8 @@ def AtenMmIntTypes_basic(module, tu: TestUtils):
 # ==============================================================================
+# For DQ-Q fake quantization ops
+import torch.ao.quantization.fx._decomposed
 class AtenMmQint8(torch.nn.Module):
@@ -352,12 +354,14 @@ class AtenMmQint8(torch.nn.Module):
         ]
     )
     def forward(self, x, y):
-        qx = torch._make_per_tensor_quantized_tensor(x, 0.0215, -25)
-        qx = torch.dequantize(qx)
-        qy = torch._make_per_tensor_quantized_tensor(y, 0.0176, 18)
-        qy = torch.dequantize(qy)
-        qz = torch.mm(qx, qy)
-        return qz
+        x = torch.ops.quantized_decomposed.dequantize_per_tensor.default(
+            x, 0.0215, -25, -128, 127, torch.int8
+        )
+        y = torch.ops.quantized_decomposed.dequantize_per_tensor.default(
+            y, 0.0176, 18, -128, 127, torch.int8
+        )
+        z = torch.mm(x, y)
+        return z
 @register_test_case(module_factory=lambda: AtenMmQint8())
@@ -384,12 +388,14 @@ class AtenMmQuint8(torch.nn.Module):
         ]
     )
     def forward(self, x, y):
-        qx = torch._make_per_tensor_quantized_tensor(x, 0.199, 65)
-        qx = torch.dequantize(qx)
-        qy = torch._make_per_tensor_quantized_tensor(y, 0.0215, 160)
-        qy = torch.dequantize(qy)
-        qz = torch.mm(qx, qy)
-        return qz
+        x = torch.ops.quantized_decomposed.dequantize_per_tensor.default(
+            x, 0.199, 65, 0, 255, torch.uint8
+        )
+        y = torch.ops.quantized_decomposed.dequantize_per_tensor.default(
+            y, 0.0215, 160, 0, 255, torch.uint8
+        )
+        z = torch.mm(x, y)
+        return z
 @register_test_case(module_factory=lambda: AtenMmQuint8())
@@ -416,12 +422,14 @@ class AtenMmQMixedSigni8(torch.nn.Module):
         ]
     )
     def forward(self, x, y):
-        qx = torch._make_per_tensor_quantized_tensor(x, 0.03, -66)
-        qx = torch.dequantize(qx)
-        qy = torch._make_per_tensor_quantized_tensor(y, 0.025, 160)
-        qy = torch.dequantize(qy)
-        qz = torch.mm(qx, qy)
-        return qz
+        x = torch.ops.quantized_decomposed.dequantize_per_tensor.default(
+            x, 0.03, -66, -128, 127, torch.int8
+        )
+        y = torch.ops.quantized_decomposed.dequantize_per_tensor.default(
+            y, 0.025, 160, 0, 255, torch.uint8
+        )
+        z = torch.mm(x, y)
+        return z
 @register_test_case(module_factory=lambda: AtenMmQMixedSigni8())
@@ -475,12 +483,14 @@ class AtenMatmulQint8VM(torch.nn.Module):
         ]
     )
     def forward(self, x, y):
-        qx = torch._make_per_tensor_quantized_tensor(x, 0.0215, -25)
-        qx = torch.dequantize(qx)
-        qy = torch._make_per_tensor_quantized_tensor(y, 0.0176, 18)
-        qy = torch.dequantize(qy)
-        qz = torch.matmul(qx, qy)
-        return qz
+        x = torch.ops.quantized_decomposed.dequantize_per_tensor.default(
+            x, 0.0215, -25, -128, 127, torch.int8
+        )
+        y = torch.ops.quantized_decomposed.dequantize_per_tensor.default(
+            y, 0.0176, 18, -128, 127, torch.int8
+        )
+        z = torch.matmul(x, y)
+        return z
 @register_test_case(module_factory=lambda: AtenMatmulQint8VM())
@@ -505,12 +515,14 @@ class AtenMatmulQint8VV(torch.nn.Module):
         ]
     )
     def forward(self, x, y):
-        qx = torch._make_per_tensor_quantized_tensor(x, 0.0215, -25)
-        qx = torch.dequantize(qx)
-        qy = torch._make_per_tensor_quantized_tensor(y, 0.0176, 18)
-        qy = torch.dequantize(qy)
-        qz = torch.matmul(qx, qy)
-        return qz
+        x = torch.ops.quantized_decomposed.dequantize_per_tensor.default(
+            x, 0.0215, -25, -128, 127, torch.int8
+        )
+        y = torch.ops.quantized_decomposed.dequantize_per_tensor.default(
+            y, 0.0176, 18, -128, 127, torch.int8
+        )
+        z = torch.matmul(x, y)
+        return z
 @register_test_case(module_factory=lambda: AtenMatmulQint8VV())
@@ -535,12 +547,14 @@ class AtenMatmulQint8MV(torch.nn.Module):
         ]
     )
     def forward(self, x, y):
-        qx = torch._make_per_tensor_quantized_tensor(x, 0.0215, -25)
-        qx = torch.dequantize(qx)
-        qy = torch._make_per_tensor_quantized_tensor(y, 0.0176, 18)
-        qy = torch.dequantize(qy)
-        qz = torch.matmul(qx, qy)
-        return qz
+        x = torch.ops.quantized_decomposed.dequantize_per_tensor.default(
+            x, 0.0215, -25, -128, 127, torch.int8
+        )
+        y = torch.ops.quantized_decomposed.dequantize_per_tensor.default(
+            y, 0.0176, 18, -128, 127, torch.int8
+        )
+        z = torch.matmul(x, y)
+        return z
 @register_test_case(module_factory=lambda: AtenMatmulQint8MV())
@@ -565,12 +579,14 @@ class AtenMatmulQint8(torch.nn.Module):
         ]
     )
     def forward(self, x, y):
-        qx = torch._make_per_tensor_quantized_tensor(x, 0.0215, -25)
-        qx = torch.dequantize(qx)
-        qy = torch._make_per_tensor_quantized_tensor(y, 0.0176, 18)
-        qy = torch.dequantize(qy)
-        qz = torch.matmul(qx, qy)
-        return qz
+        x = torch.ops.quantized_decomposed.dequantize_per_tensor.default(
+            x, 0.0215, -25, -128, 127, torch.int8
+        )
+        y = torch.ops.quantized_decomposed.dequantize_per_tensor.default(
+            y, 0.0176, 18, -128, 127, torch.int8
+        )
+        z = torch.matmul(x, y)
+        return z
 @register_test_case(module_factory=lambda: AtenMatmulQint8())
@@ -597,12 +613,14 @@ class AtenMatmulQMixedSigni8(torch.nn.Module):
         ]
     )
     def forward(self, x, y):
-        qx = torch._make_per_tensor_quantized_tensor(x, 0.03, -66)
-        qx = torch.dequantize(qx)
-        qy = torch._make_per_tensor_quantized_tensor(y, 0.025, 160)
-        qy = torch.dequantize(qy)
-        qz = torch.matmul(qx, qy)
-        return qz
+        x = torch.ops.quantized_decomposed.dequantize_per_tensor.default(
+            x, 0.03, -66, -128, 127, torch.int8
+        )
+        y = torch.ops.quantized_decomposed.dequantize_per_tensor.default(
+            y, 0.025, 160, 0, 255, torch.uint8
+        )
+        z = torch.matmul(x, y)
+        return z
 @register_test_case(module_factory=lambda: AtenMatmulQMixedSigni8())
@@ -629,13 +647,15 @@ class AtenMatmulQMixedSigni8Transpose(torch.nn.Module):
         ]
     )
     def forward(self, x, y):
-        qx = torch._make_per_tensor_quantized_tensor(x, 0.03, -66)
-        qx = torch.dequantize(qx)
-        qy = torch._make_per_tensor_quantized_tensor(y, 0.025, 160)
-        qy = torch.dequantize(qy)
-        qy = torch.transpose(qy, 1, 2)
-        qz = torch.matmul(qx, qy)
-        return qz
+        x = torch.ops.quantized_decomposed.dequantize_per_tensor.default(
+            x, 0.03, -66, -128, 127, torch.int8
+        )
+        y = torch.ops.quantized_decomposed.dequantize_per_tensor.default(
+            y, 0.025, 160, 0, 255, torch.uint8
+        )
+        y = torch.transpose(y, 1, 2)
+        z = torch.matmul(x, y)
+        return z
 @register_test_case(module_factory=lambda: AtenMatmulQMixedSigni8Transpose())
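
For reference, after this change one of these tests reads end to end roughly
as follows. This is a sketch: the harness imports follow the
torch_mlir_e2e_test conventions, but the annotated shapes and the tu.randint
ranges are illustrative assumptions, since the diff above only shows the
forward bodies.

import torch
import torch.ao.quantization.fx._decomposed  # For DQ-Q fake quantization ops

from torch_mlir_e2e_test.framework import TestUtils
from torch_mlir_e2e_test.registry import register_test_case
from torch_mlir_e2e_test.annotations import annotate_args, export


class AtenMmQint8(torch.nn.Module):
    @export
    @annotate_args(
        [
            None,
            ([3, 4], torch.int8, True),  # shapes are illustrative assumptions
            ([4, 3], torch.int8, True),
        ]
    )
    def forward(self, x, y):
        # Dequantize both int8 operands to float, then matmul in float;
        # the backend is expected to fold the DQ ops into a quantized matmul.
        x = torch.ops.quantized_decomposed.dequantize_per_tensor.default(
            x, 0.0215, -25, -128, 127, torch.int8
        )
        y = torch.ops.quantized_decomposed.dequantize_per_tensor.default(
            y, 0.0176, 18, -128, 127, torch.int8
        )
        z = torch.mm(x, y)
        return z


@register_test_case(module_factory=lambda: AtenMmQint8())
def AtenMmQint8_basic(module, tu: TestUtils):
    module.forward(
        tu.randint(3, 4, low=-128, high=127).to(torch.int8),
        tu.randint(4, 3, low=-128, high=127).to(torch.int8),
    )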