mirror of https://github.com/llvm/torch-mlir
Update quantized matmul tests to DQ/Q format supported by fx_importer (#3815)
Continuation of https://github.com/llvm/torch-mlir/pull/3809 for the matmul tests.

Branch: pull/3835/head
parent 1259e8a00a
commit 76209db5a5
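The change pattern throughout: the old tests built real quantized tensors with torch._make_per_tensor_quantized_tensor and then called torch.dequantize, while the new tests keep plain int8/uint8 tensors and call the decomposed dequantize op directly, which the fx_importer can trace. A minimal side-by-side sketch (the scale/zero-point constants match AtenMmQint8 below; the random input is illustrative, not from the patch):

import torch
import torch.ao.quantization.fx._decomposed  # noqa: F401 -- registers quantized_decomposed ops

x_int = torch.randint(-128, 128, (3, 4), dtype=torch.int8)

# Old form: materialize a real qint8 tensor, then dequantize it.
qx = torch._make_per_tensor_quantized_tensor(x_int, 0.0215, -25)
old = torch.dequantize(qx)

# New DQ form: decomposed op on the raw int8 tensor.
# Positional args: (input, scale, zero_point, quant_min, quant_max, dtype).
new = torch.ops.quantized_decomposed.dequantize_per_tensor.default(
    x_int, 0.0215, -25, -128, 127, torch.int8
)

# Both compute (x_int.float() - zero_point) * scale.
assert torch.allclose(old, new)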
@@ -394,15 +394,6 @@ FX_IMPORTER_XFAIL_SET = {
     "AtenIntBoolOpModule_basic",
     "AtenIntMM_basic",
     "AtenItemFpOpModule_basic",
-    "AtenMatmulQMixedSigni8Transpose_basic",
-    "AtenMatmulQMixedSigni8_basic",
-    "AtenMatmulQint8MV_basic",
-    "AtenMatmulQint8_basic",
-    "AtenMatmulQint8VM_basic",
-    "AtenMatmulQint8VV_basic",
-    "AtenMmQMixedSigni8_basic",
-    "AtenMmQint8_basic",
-    "AtenMmQuint8_basic",
     "QuantizedReluInt32_basic",
     "QuantizedReluInt8_basic",
     "QuantizedReluUint8_basic",

@@ -337,6 +337,8 @@ def AtenMmIntTypes_basic(module, tu: TestUtils):
 
 
 # ==============================================================================
+# For DQ-Q fake quantization ops
+import torch.ao.quantization.fx._decomposed
 
 
 class AtenMmQint8(torch.nn.Module):

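A note on the new import above: it is load-bearing, not cosmetic. torch.ops.quantized_decomposed is an op namespace populated by the torch.library registrations that run as a side effect of importing torch.ao.quantization.fx._decomposed, so the test suite imports the module purely for that side effect. A sketch of the dependency (standalone, not part of the patch):

import torch
import torch.ao.quantization.fx._decomposed  # noqa: F401 -- side-effect import

# Resolves only after the import above has registered the ops.
dq = torch.ops.quantized_decomposed.dequantize_per_tensor.default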
@@ -352,12 +354,14 @@ class AtenMmQint8(torch.nn.Module):
         ]
     )
     def forward(self, x, y):
-        qx = torch._make_per_tensor_quantized_tensor(x, 0.0215, -25)
-        qx = torch.dequantize(qx)
-        qy = torch._make_per_tensor_quantized_tensor(y, 0.0176, 18)
-        qy = torch.dequantize(qy)
-        qz = torch.mm(qx, qy)
-        return qz
+        x = torch.torch.ops.quantized_decomposed.dequantize_per_tensor.default(
+            x, 0.0215, -25, -128, 127, torch.int8
+        )
+        y = torch.torch.ops.quantized_decomposed.dequantize_per_tensor.default(
+            y, 0.0176, 18, -128, 127, torch.int8
+        )
+        z = torch.mm(x, y)
+        return z
 
 
 @register_test_case(module_factory=lambda: AtenMmQint8())

@@ -384,12 +388,14 @@ class AtenMmQuint8(torch.nn.Module):
         ]
     )
     def forward(self, x, y):
-        qx = torch._make_per_tensor_quantized_tensor(x, 0.199, 65)
-        qx = torch.dequantize(qx)
-        qy = torch._make_per_tensor_quantized_tensor(y, 0.0215, 160)
-        qy = torch.dequantize(qy)
-        qz = torch.mm(qx, qy)
-        return qz
+        x = torch.torch.ops.quantized_decomposed.dequantize_per_tensor.default(
+            x, 0.199, 65, 0, 255, torch.uint8
+        )
+        y = torch.torch.ops.quantized_decomposed.dequantize_per_tensor.default(
+            y, 0.0215, 160, 0, 255, torch.uint8
+        )
+        z = torch.mm(x, y)
+        return z
 
 
 @register_test_case(module_factory=lambda: AtenMmQuint8())

@@ -416,12 +422,14 @@ class AtenMmQMixedSigni8(torch.nn.Module):
         ]
     )
     def forward(self, x, y):
-        qx = torch._make_per_tensor_quantized_tensor(x, 0.03, -66)
-        qx = torch.dequantize(qx)
-        qy = torch._make_per_tensor_quantized_tensor(y, 0.025, 160)
-        qy = torch.dequantize(qy)
-        qz = torch.mm(qx, qy)
-        return qz
+        x = torch.torch.ops.quantized_decomposed.dequantize_per_tensor.default(
+            x, 0.03, -66, -128, 127, torch.int8
+        )
+        y = torch.torch.ops.quantized_decomposed.dequantize_per_tensor.default(
+            y, 0.025, 160, 0, 255, torch.uint8
+        )
+        z = torch.mm(x, y)
+        return z
 
 
 @register_test_case(module_factory=lambda: AtenMmQMixedSigni8())

@@ -475,12 +483,14 @@ class AtenMatmulQint8VM(torch.nn.Module):
         ]
     )
     def forward(self, x, y):
-        qx = torch._make_per_tensor_quantized_tensor(x, 0.0215, -25)
-        qx = torch.dequantize(qx)
-        qy = torch._make_per_tensor_quantized_tensor(y, 0.0176, 18)
-        qy = torch.dequantize(qy)
-        qz = torch.matmul(qx, qy)
-        return qz
+        x = torch.torch.ops.quantized_decomposed.dequantize_per_tensor.default(
+            x, 0.0215, -25, -128, 127, torch.int8
+        )
+        y = torch.torch.ops.quantized_decomposed.dequantize_per_tensor.default(
+            y, 0.0176, 18, -128, 127, torch.int8
+        )
+        z = torch.matmul(x, y)
+        return z
 
 
 @register_test_case(module_factory=lambda: AtenMatmulQint8VM())

@@ -505,12 +515,14 @@ class AtenMatmulQint8VV(torch.nn.Module):
         ]
     )
     def forward(self, x, y):
-        qx = torch._make_per_tensor_quantized_tensor(x, 0.0215, -25)
-        qx = torch.dequantize(qx)
-        qy = torch._make_per_tensor_quantized_tensor(y, 0.0176, 18)
-        qy = torch.dequantize(qy)
-        qz = torch.matmul(qx, qy)
-        return qz
+        x = torch.torch.ops.quantized_decomposed.dequantize_per_tensor.default(
+            x, 0.0215, -25, -128, 127, torch.int8
+        )
+        y = torch.torch.ops.quantized_decomposed.dequantize_per_tensor.default(
+            y, 0.0176, 18, -128, 127, torch.int8
+        )
+        z = torch.matmul(x, y)
+        return z
 
 
 @register_test_case(module_factory=lambda: AtenMatmulQint8VV())

@@ -535,12 +547,14 @@ class AtenMatmulQint8MV(torch.nn.Module):
         ]
     )
     def forward(self, x, y):
-        qx = torch._make_per_tensor_quantized_tensor(x, 0.0215, -25)
-        qx = torch.dequantize(qx)
-        qy = torch._make_per_tensor_quantized_tensor(y, 0.0176, 18)
-        qy = torch.dequantize(qy)
-        qz = torch.matmul(qx, qy)
-        return qz
+        x = torch.torch.ops.quantized_decomposed.dequantize_per_tensor.default(
+            x, 0.0215, -25, -128, 127, torch.int8
+        )
+        y = torch.torch.ops.quantized_decomposed.dequantize_per_tensor.default(
+            y, 0.0176, 18, -128, 127, torch.int8
+        )
+        z = torch.matmul(x, y)
+        return z
 
 
 @register_test_case(module_factory=lambda: AtenMatmulQint8MV())

@@ -565,12 +579,14 @@ class AtenMatmulQint8(torch.nn.Module):
         ]
     )
     def forward(self, x, y):
-        qx = torch._make_per_tensor_quantized_tensor(x, 0.0215, -25)
-        qx = torch.dequantize(qx)
-        qy = torch._make_per_tensor_quantized_tensor(y, 0.0176, 18)
-        qy = torch.dequantize(qy)
-        qz = torch.matmul(qx, qy)
-        return qz
+        x = torch.torch.ops.quantized_decomposed.dequantize_per_tensor.default(
+            x, 0.0215, -25, -128, 127, torch.int8
+        )
+        y = torch.torch.ops.quantized_decomposed.dequantize_per_tensor.default(
+            y, 0.0176, 18, -128, 127, torch.int8
+        )
+        z = torch.matmul(x, y)
+        return z
 
 
 @register_test_case(module_factory=lambda: AtenMatmulQint8())

@@ -597,12 +613,14 @@ class AtenMatmulQMixedSigni8(torch.nn.Module):
         ]
     )
     def forward(self, x, y):
-        qx = torch._make_per_tensor_quantized_tensor(x, 0.03, -66)
-        qx = torch.dequantize(qx)
-        qy = torch._make_per_tensor_quantized_tensor(y, 0.025, 160)
-        qy = torch.dequantize(qy)
-        qz = torch.matmul(qx, qy)
-        return qz
+        x = torch.torch.ops.quantized_decomposed.dequantize_per_tensor.default(
+            x, 0.03, -66, -128, 127, torch.int8
+        )
+        y = torch.torch.ops.quantized_decomposed.dequantize_per_tensor.default(
+            y, 0.025, 160, 0, 255, torch.uint8
+        )
+        z = torch.matmul(x, y)
+        return z
 
 
 @register_test_case(module_factory=lambda: AtenMatmulQMixedSigni8())

@@ -629,13 +647,15 @@ class AtenMatmulQMixedSigni8Transpose(torch.nn.Module):
         ]
     )
     def forward(self, x, y):
-        qx = torch._make_per_tensor_quantized_tensor(x, 0.03, -66)
-        qx = torch.dequantize(qx)
-        qy = torch._make_per_tensor_quantized_tensor(y, 0.025, 160)
-        qy = torch.dequantize(qy)
-        qy = torch.transpose(qy, 1, 2)
-        qz = torch.matmul(qx, qy)
-        return qz
+        x = torch.torch.ops.quantized_decomposed.dequantize_per_tensor.default(
+            x, 0.03, -66, -128, 127, torch.int8
+        )
+        y = torch.torch.ops.quantized_decomposed.dequantize_per_tensor.default(
+            y, 0.025, 160, 0, 255, torch.uint8
+        )
+        y = torch.transpose(y, 1, 2)
+        z = torch.matmul(x, y)
+        return z
 
 
 @register_test_case(module_factory=lambda: AtenMatmulQMixedSigni8Transpose())
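As a sanity check on the rewrite (a sketch, not part of the commit), the pre-patch and post-patch forward bodies should produce identical floats; here with illustrative shapes and the int8 constants from AtenMmQint8:

import torch
import torch.ao.quantization.fx._decomposed  # noqa: F401

def old_forward(x, y):
    # Pre-patch: real quantized tensors, then dequantize.
    qx = torch.dequantize(torch._make_per_tensor_quantized_tensor(x, 0.0215, -25))
    qy = torch.dequantize(torch._make_per_tensor_quantized_tensor(y, 0.0176, 18))
    return torch.mm(qx, qy)

def new_forward(x, y):
    # Post-patch: decomposed dequantize ops traceable by the fx_importer.
    x = torch.ops.quantized_decomposed.dequantize_per_tensor.default(
        x, 0.0215, -25, -128, 127, torch.int8
    )
    y = torch.ops.quantized_decomposed.dequantize_per_tensor.default(
        y, 0.0176, 18, -128, 127, torch.int8
    )
    return torch.mm(x, y)

x = torch.randint(-128, 128, (3, 4), dtype=torch.int8)
y = torch.randint(-128, 128, (4, 3), dtype=torch.int8)
assert torch.allclose(old_forward(x, y), new_forward(x, y))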