From 76209db5a5817e098cfced7f065a0f54e6b09d13 Mon Sep 17 00:00:00 2001
From: Felix Schneider
Date: Thu, 24 Oct 2024 21:59:58 +0200
Subject: [PATCH] Update quantized matmul tests to DQ/Q format supported by
 fx_importer (#3815)

Continuation of https://github.com/llvm/torch-mlir/pull/3809 for the
matmul tests.
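
Each test now dequantizes its integer inputs with the decomposed DQ op
instead of building a quantized tensor first. Roughly (a sketch; scale,
zero_point, quant_min, quant_max, and the dtype vary per test):

    # Before: quantized-tensor ops, which the fx_importer does not support.
    qx = torch._make_per_tensor_quantized_tensor(x, scale, zero_point)
    qx = torch.dequantize(qx)

    # After: decomposed DQ op, registered by importing
    # torch.ao.quantization.fx._decomposed.
    x = torch.ops.quantized_decomposed.dequantize_per_tensor.default(
        x, scale, zero_point, quant_min, quant_max, torch.int8
    )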
---
 projects/pt1/e2e_testing/xfail_sets.py        |   9 --
 .../torch_mlir_e2e_test/test_suite/matmul.py  | 130 ++++++++++--------
 2 files changed, 75 insertions(+), 64 deletions(-)

diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py
index ab5c54b76..553a27924 100644
--- a/projects/pt1/e2e_testing/xfail_sets.py
+++ b/projects/pt1/e2e_testing/xfail_sets.py
@@ -394,15 +394,6 @@ FX_IMPORTER_XFAIL_SET = {
     "AtenIntBoolOpModule_basic",
     "AtenIntMM_basic",
     "AtenItemFpOpModule_basic",
-    "AtenMatmulQMixedSigni8Transpose_basic",
-    "AtenMatmulQMixedSigni8_basic",
-    "AtenMatmulQint8MV_basic",
-    "AtenMatmulQint8_basic",
-    "AtenMatmulQint8VM_basic",
-    "AtenMatmulQint8VV_basic",
-    "AtenMmQMixedSigni8_basic",
-    "AtenMmQint8_basic",
-    "AtenMmQuint8_basic",
     "QuantizedReluInt32_basic",
     "QuantizedReluInt8_basic",
     "QuantizedReluUint8_basic",
diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/matmul.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/matmul.py
index 40e6a7359..17240cf95 100644
--- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/matmul.py
+++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/matmul.py
@@ -337,6 +337,8 @@ def AtenMmIntTypes_basic(module, tu: TestUtils):
 
 # ==============================================================================
 
+# For DQ-Q fake quantization ops
+import torch.ao.quantization.fx._decomposed
 
 
 class AtenMmQint8(torch.nn.Module):
@@ -352,12 +354,14 @@ class AtenMmQint8(torch.nn.Module):
         ]
     )
     def forward(self, x, y):
-        qx = torch._make_per_tensor_quantized_tensor(x, 0.0215, -25)
-        qx = torch.dequantize(qx)
-        qy = torch._make_per_tensor_quantized_tensor(y, 0.0176, 18)
-        qy = torch.dequantize(qy)
-        qz = torch.mm(qx, qy)
-        return qz
+        x = torch.ops.quantized_decomposed.dequantize_per_tensor.default(
+            x, 0.0215, -25, -128, 127, torch.int8
+        )
+        y = torch.ops.quantized_decomposed.dequantize_per_tensor.default(
+            y, 0.0176, 18, -128, 127, torch.int8
+        )
+        z = torch.mm(x, y)
+        return z
 
 
 @register_test_case(module_factory=lambda: AtenMmQint8())
@@ -384,12 +388,14 @@ class AtenMmQuint8(torch.nn.Module):
         ]
     )
     def forward(self, x, y):
-        qx = torch._make_per_tensor_quantized_tensor(x, 0.199, 65)
-        qx = torch.dequantize(qx)
-        qy = torch._make_per_tensor_quantized_tensor(y, 0.0215, 160)
-        qy = torch.dequantize(qy)
-        qz = torch.mm(qx, qy)
-        return qz
+        x = torch.ops.quantized_decomposed.dequantize_per_tensor.default(
+            x, 0.199, 65, 0, 255, torch.uint8
+        )
+        y = torch.ops.quantized_decomposed.dequantize_per_tensor.default(
+            y, 0.0215, 160, 0, 255, torch.uint8
+        )
+        z = torch.mm(x, y)
+        return z
 
 
 @register_test_case(module_factory=lambda: AtenMmQuint8())
@@ -416,12 +422,14 @@ class AtenMmQMixedSigni8(torch.nn.Module):
         ]
     )
     def forward(self, x, y):
-        qx = torch._make_per_tensor_quantized_tensor(x, 0.03, -66)
-        qx = torch.dequantize(qx)
-        qy = torch._make_per_tensor_quantized_tensor(y, 0.025, 160)
-        qy = torch.dequantize(qy)
-        qz = torch.mm(qx, qy)
-        return qz
+        x = torch.ops.quantized_decomposed.dequantize_per_tensor.default(
+            x, 0.03, -66, -128, 127, torch.int8
+        )
+        y = torch.ops.quantized_decomposed.dequantize_per_tensor.default(
+            y, 0.025, 160, 0, 255, torch.uint8
+        )
+        z = torch.mm(x, y)
+        return z
 
 
 @register_test_case(module_factory=lambda: AtenMmQMixedSigni8())
@@ -475,12 +483,14 @@ class AtenMatmulQint8VM(torch.nn.Module):
         ]
     )
     def forward(self, x, y):
-        qx = torch._make_per_tensor_quantized_tensor(x, 0.0215, -25)
-        qx = torch.dequantize(qx)
-        qy = torch._make_per_tensor_quantized_tensor(y, 0.0176, 18)
-        qy = torch.dequantize(qy)
-        qz = torch.matmul(qx, qy)
-        return qz
+        x = torch.ops.quantized_decomposed.dequantize_per_tensor.default(
+            x, 0.0215, -25, -128, 127, torch.int8
+        )
+        y = torch.ops.quantized_decomposed.dequantize_per_tensor.default(
+            y, 0.0176, 18, -128, 127, torch.int8
+        )
+        z = torch.matmul(x, y)
+        return z
 
 
 @register_test_case(module_factory=lambda: AtenMatmulQint8VM())
@@ -505,12 +515,14 @@ class AtenMatmulQint8VV(torch.nn.Module):
         ]
     )
     def forward(self, x, y):
-        qx = torch._make_per_tensor_quantized_tensor(x, 0.0215, -25)
-        qx = torch.dequantize(qx)
-        qy = torch._make_per_tensor_quantized_tensor(y, 0.0176, 18)
-        qy = torch.dequantize(qy)
-        qz = torch.matmul(qx, qy)
-        return qz
+        x = torch.ops.quantized_decomposed.dequantize_per_tensor.default(
+            x, 0.0215, -25, -128, 127, torch.int8
+        )
+        y = torch.ops.quantized_decomposed.dequantize_per_tensor.default(
+            y, 0.0176, 18, -128, 127, torch.int8
+        )
+        z = torch.matmul(x, y)
+        return z
 
 
 @register_test_case(module_factory=lambda: AtenMatmulQint8VV())
@@ -535,12 +547,14 @@ class AtenMatmulQint8MV(torch.nn.Module):
         ]
     )
     def forward(self, x, y):
-        qx = torch._make_per_tensor_quantized_tensor(x, 0.0215, -25)
-        qx = torch.dequantize(qx)
-        qy = torch._make_per_tensor_quantized_tensor(y, 0.0176, 18)
-        qy = torch.dequantize(qy)
-        qz = torch.matmul(qx, qy)
-        return qz
+        x = torch.ops.quantized_decomposed.dequantize_per_tensor.default(
+            x, 0.0215, -25, -128, 127, torch.int8
+        )
+        y = torch.ops.quantized_decomposed.dequantize_per_tensor.default(
+            y, 0.0176, 18, -128, 127, torch.int8
+        )
+        z = torch.matmul(x, y)
+        return z
 
 
 @register_test_case(module_factory=lambda: AtenMatmulQint8MV())
@@ -565,12 +579,14 @@ class AtenMatmulQint8(torch.nn.Module):
         ]
     )
     def forward(self, x, y):
-        qx = torch._make_per_tensor_quantized_tensor(x, 0.0215, -25)
-        qx = torch.dequantize(qx)
-        qy = torch._make_per_tensor_quantized_tensor(y, 0.0176, 18)
-        qy = torch.dequantize(qy)
-        qz = torch.matmul(qx, qy)
-        return qz
+        x = torch.ops.quantized_decomposed.dequantize_per_tensor.default(
+            x, 0.0215, -25, -128, 127, torch.int8
+        )
+        y = torch.ops.quantized_decomposed.dequantize_per_tensor.default(
+            y, 0.0176, 18, -128, 127, torch.int8
+        )
+        z = torch.matmul(x, y)
+        return z
 
 
 @register_test_case(module_factory=lambda: AtenMatmulQint8())
@@ -597,12 +613,14 @@ class AtenMatmulQMixedSigni8(torch.nn.Module):
         ]
     )
     def forward(self, x, y):
-        qx = torch._make_per_tensor_quantized_tensor(x, 0.03, -66)
-        qx = torch.dequantize(qx)
-        qy = torch._make_per_tensor_quantized_tensor(y, 0.025, 160)
-        qy = torch.dequantize(qy)
-        qz = torch.matmul(qx, qy)
-        return qz
+        x = torch.ops.quantized_decomposed.dequantize_per_tensor.default(
+            x, 0.03, -66, -128, 127, torch.int8
+        )
+        y = torch.ops.quantized_decomposed.dequantize_per_tensor.default(
+            y, 0.025, 160, 0, 255, torch.uint8
+        )
+        z = torch.matmul(x, y)
+        return z
 
 
 @register_test_case(module_factory=lambda: AtenMatmulQMixedSigni8())
@@ -629,13 +647,15 @@ class AtenMatmulQMixedSigni8Transpose(torch.nn.Module):
         ]
     )
     def forward(self, x, y):
-        qx = torch._make_per_tensor_quantized_tensor(x, 0.03, -66)
-        qx = torch.dequantize(qx)
-        qy = torch._make_per_tensor_quantized_tensor(y, 0.025, 160)
-        qy = torch.dequantize(qy)
-        qy = torch.transpose(qy, 1, 2)
-        qz = torch.matmul(qx, qy)
-        return qz
+        x = torch.ops.quantized_decomposed.dequantize_per_tensor.default(
+            x, 0.03, -66, -128, 127, torch.int8
+        )
+        y = torch.ops.quantized_decomposed.dequantize_per_tensor.default(
+            y, 0.025, 160, 0, 255, torch.uint8
+        )
+        y = torch.transpose(y, 1, 2)
+        z = torch.matmul(x, y)
+        return z
 
 
 @register_test_case(module_factory=lambda: AtenMatmulQMixedSigni8Transpose())
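
Note: a minimal sketch (not part of the patch) checking that the decomposed
DQ op produces the same values as the old quantized-tensor formulation; the
scale/zero-point pair is borrowed from AtenMmQint8 above, the shape is
arbitrary:

    import torch
    # Importing this module registers the quantized_decomposed ops.
    import torch.ao.quantization.fx._decomposed

    x = torch.randint(-128, 128, (3, 4), dtype=torch.int8)
    # Old formulation: attach qparams to the raw integers, then dequantize.
    old = torch.dequantize(torch._make_per_tensor_quantized_tensor(x, 0.0215, -25))
    # New formulation: decomposed DQ op, computing (x.float() - zero_point) * scale.
    new = torch.ops.quantized_decomposed.dequantize_per_tensor.default(
        x, 0.0215, -25, -128, 127, torch.int8
    )
    assert torch.allclose(old, new)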