mirror of https://github.com/llvm/torch-mlir
Update quantized matmul tests to DQ/Q format supported by fx_importer (#3815)
Continuation of https://github.com/llvm/torch-mlir/pull/3809 for the matmul tests.pull/3835/head
parent
1259e8a00a
commit
76209db5a5
|
@ -394,15 +394,6 @@ FX_IMPORTER_XFAIL_SET = {
|
||||||
"AtenIntBoolOpModule_basic",
|
"AtenIntBoolOpModule_basic",
|
||||||
"AtenIntMM_basic",
|
"AtenIntMM_basic",
|
||||||
"AtenItemFpOpModule_basic",
|
"AtenItemFpOpModule_basic",
|
||||||
"AtenMatmulQMixedSigni8Transpose_basic",
|
|
||||||
"AtenMatmulQMixedSigni8_basic",
|
|
||||||
"AtenMatmulQint8MV_basic",
|
|
||||||
"AtenMatmulQint8_basic",
|
|
||||||
"AtenMatmulQint8VM_basic",
|
|
||||||
"AtenMatmulQint8VV_basic",
|
|
||||||
"AtenMmQMixedSigni8_basic",
|
|
||||||
"AtenMmQint8_basic",
|
|
||||||
"AtenMmQuint8_basic",
|
|
||||||
"QuantizedReluInt32_basic",
|
"QuantizedReluInt32_basic",
|
||||||
"QuantizedReluInt8_basic",
|
"QuantizedReluInt8_basic",
|
||||||
"QuantizedReluUint8_basic",
|
"QuantizedReluUint8_basic",
|
||||||
|
|
|
@ -337,6 +337,8 @@ def AtenMmIntTypes_basic(module, tu: TestUtils):
|
||||||
|
|
||||||
|
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
# For DQ-Q fake quantization ops
|
||||||
|
import torch.ao.quantization.fx._decomposed
|
||||||
|
|
||||||
|
|
||||||
class AtenMmQint8(torch.nn.Module):
|
class AtenMmQint8(torch.nn.Module):
|
||||||
|
@ -352,12 +354,14 @@ class AtenMmQint8(torch.nn.Module):
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
def forward(self, x, y):
|
def forward(self, x, y):
|
||||||
qx = torch._make_per_tensor_quantized_tensor(x, 0.0215, -25)
|
x = torch.torch.ops.quantized_decomposed.dequantize_per_tensor.default(
|
||||||
qx = torch.dequantize(qx)
|
x, 0.0215, -25, -128, 127, torch.int8
|
||||||
qy = torch._make_per_tensor_quantized_tensor(y, 0.0176, 18)
|
)
|
||||||
qy = torch.dequantize(qy)
|
y = torch.torch.ops.quantized_decomposed.dequantize_per_tensor.default(
|
||||||
qz = torch.mm(qx, qy)
|
y, 0.0176, 18, -128, 127, torch.int8
|
||||||
return qz
|
)
|
||||||
|
z = torch.mm(x, y)
|
||||||
|
return z
|
||||||
|
|
||||||
|
|
||||||
@register_test_case(module_factory=lambda: AtenMmQint8())
|
@register_test_case(module_factory=lambda: AtenMmQint8())
|
||||||
|
@ -384,12 +388,14 @@ class AtenMmQuint8(torch.nn.Module):
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
def forward(self, x, y):
|
def forward(self, x, y):
|
||||||
qx = torch._make_per_tensor_quantized_tensor(x, 0.199, 65)
|
x = torch.torch.ops.quantized_decomposed.dequantize_per_tensor.default(
|
||||||
qx = torch.dequantize(qx)
|
x, 0.199, 65, 0, 255, torch.uint8
|
||||||
qy = torch._make_per_tensor_quantized_tensor(y, 0.0215, 160)
|
)
|
||||||
qy = torch.dequantize(qy)
|
y = torch.torch.ops.quantized_decomposed.dequantize_per_tensor.default(
|
||||||
qz = torch.mm(qx, qy)
|
y, 0.0215, 160, 0, 255, torch.uint8
|
||||||
return qz
|
)
|
||||||
|
z = torch.mm(x, y)
|
||||||
|
return z
|
||||||
|
|
||||||
|
|
||||||
@register_test_case(module_factory=lambda: AtenMmQuint8())
|
@register_test_case(module_factory=lambda: AtenMmQuint8())
|
||||||
|
@ -416,12 +422,14 @@ class AtenMmQMixedSigni8(torch.nn.Module):
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
def forward(self, x, y):
|
def forward(self, x, y):
|
||||||
qx = torch._make_per_tensor_quantized_tensor(x, 0.03, -66)
|
x = torch.torch.ops.quantized_decomposed.dequantize_per_tensor.default(
|
||||||
qx = torch.dequantize(qx)
|
x, 0.03, -66, -128, 127, torch.int8
|
||||||
qy = torch._make_per_tensor_quantized_tensor(y, 0.025, 160)
|
)
|
||||||
qy = torch.dequantize(qy)
|
y = torch.torch.ops.quantized_decomposed.dequantize_per_tensor.default(
|
||||||
qz = torch.mm(qx, qy)
|
y, 0.025, 160, 0, 255, torch.uint8
|
||||||
return qz
|
)
|
||||||
|
z = torch.mm(x, y)
|
||||||
|
return z
|
||||||
|
|
||||||
|
|
||||||
@register_test_case(module_factory=lambda: AtenMmQMixedSigni8())
|
@register_test_case(module_factory=lambda: AtenMmQMixedSigni8())
|
||||||
|
@ -475,12 +483,14 @@ class AtenMatmulQint8VM(torch.nn.Module):
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
def forward(self, x, y):
|
def forward(self, x, y):
|
||||||
qx = torch._make_per_tensor_quantized_tensor(x, 0.0215, -25)
|
x = torch.torch.ops.quantized_decomposed.dequantize_per_tensor.default(
|
||||||
qx = torch.dequantize(qx)
|
x, 0.0215, -25, -128, 127, torch.int8
|
||||||
qy = torch._make_per_tensor_quantized_tensor(y, 0.0176, 18)
|
)
|
||||||
qy = torch.dequantize(qy)
|
y = torch.torch.ops.quantized_decomposed.dequantize_per_tensor.default(
|
||||||
qz = torch.matmul(qx, qy)
|
y, 0.0176, 18, -128, 127, torch.int8
|
||||||
return qz
|
)
|
||||||
|
z = torch.matmul(x, y)
|
||||||
|
return z
|
||||||
|
|
||||||
|
|
||||||
@register_test_case(module_factory=lambda: AtenMatmulQint8VM())
|
@register_test_case(module_factory=lambda: AtenMatmulQint8VM())
|
||||||
|
@ -505,12 +515,14 @@ class AtenMatmulQint8VV(torch.nn.Module):
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
def forward(self, x, y):
|
def forward(self, x, y):
|
||||||
qx = torch._make_per_tensor_quantized_tensor(x, 0.0215, -25)
|
x = torch.torch.ops.quantized_decomposed.dequantize_per_tensor.default(
|
||||||
qx = torch.dequantize(qx)
|
x, 0.0215, -25, -128, 127, torch.int8
|
||||||
qy = torch._make_per_tensor_quantized_tensor(y, 0.0176, 18)
|
)
|
||||||
qy = torch.dequantize(qy)
|
y = torch.torch.ops.quantized_decomposed.dequantize_per_tensor.default(
|
||||||
qz = torch.matmul(qx, qy)
|
y, 0.0176, 18, -128, 127, torch.int8
|
||||||
return qz
|
)
|
||||||
|
z = torch.matmul(x, y)
|
||||||
|
return z
|
||||||
|
|
||||||
|
|
||||||
@register_test_case(module_factory=lambda: AtenMatmulQint8VV())
|
@register_test_case(module_factory=lambda: AtenMatmulQint8VV())
|
||||||
|
@ -535,12 +547,14 @@ class AtenMatmulQint8MV(torch.nn.Module):
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
def forward(self, x, y):
|
def forward(self, x, y):
|
||||||
qx = torch._make_per_tensor_quantized_tensor(x, 0.0215, -25)
|
x = torch.torch.ops.quantized_decomposed.dequantize_per_tensor.default(
|
||||||
qx = torch.dequantize(qx)
|
x, 0.0215, -25, -128, 127, torch.int8
|
||||||
qy = torch._make_per_tensor_quantized_tensor(y, 0.0176, 18)
|
)
|
||||||
qy = torch.dequantize(qy)
|
y = torch.torch.ops.quantized_decomposed.dequantize_per_tensor.default(
|
||||||
qz = torch.matmul(qx, qy)
|
y, 0.0176, 18, -128, 127, torch.int8
|
||||||
return qz
|
)
|
||||||
|
z = torch.matmul(x, y)
|
||||||
|
return z
|
||||||
|
|
||||||
|
|
||||||
@register_test_case(module_factory=lambda: AtenMatmulQint8MV())
|
@register_test_case(module_factory=lambda: AtenMatmulQint8MV())
|
||||||
|
@ -565,12 +579,14 @@ class AtenMatmulQint8(torch.nn.Module):
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
def forward(self, x, y):
|
def forward(self, x, y):
|
||||||
qx = torch._make_per_tensor_quantized_tensor(x, 0.0215, -25)
|
x = torch.torch.ops.quantized_decomposed.dequantize_per_tensor.default(
|
||||||
qx = torch.dequantize(qx)
|
x, 0.0215, -25, -128, 127, torch.int8
|
||||||
qy = torch._make_per_tensor_quantized_tensor(y, 0.0176, 18)
|
)
|
||||||
qy = torch.dequantize(qy)
|
y = torch.torch.ops.quantized_decomposed.dequantize_per_tensor.default(
|
||||||
qz = torch.matmul(qx, qy)
|
y, 0.0176, 18, -128, 127, torch.int8
|
||||||
return qz
|
)
|
||||||
|
z = torch.matmul(x, y)
|
||||||
|
return z
|
||||||
|
|
||||||
|
|
||||||
@register_test_case(module_factory=lambda: AtenMatmulQint8())
|
@register_test_case(module_factory=lambda: AtenMatmulQint8())
|
||||||
|
@ -597,12 +613,14 @@ class AtenMatmulQMixedSigni8(torch.nn.Module):
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
def forward(self, x, y):
|
def forward(self, x, y):
|
||||||
qx = torch._make_per_tensor_quantized_tensor(x, 0.03, -66)
|
x = torch.torch.ops.quantized_decomposed.dequantize_per_tensor.default(
|
||||||
qx = torch.dequantize(qx)
|
x, 0.03, -66, -128, 127, torch.int8
|
||||||
qy = torch._make_per_tensor_quantized_tensor(y, 0.025, 160)
|
)
|
||||||
qy = torch.dequantize(qy)
|
y = torch.torch.ops.quantized_decomposed.dequantize_per_tensor.default(
|
||||||
qz = torch.matmul(qx, qy)
|
y, 0.025, 160, 0, 255, torch.uint8
|
||||||
return qz
|
)
|
||||||
|
z = torch.matmul(x, y)
|
||||||
|
return z
|
||||||
|
|
||||||
|
|
||||||
@register_test_case(module_factory=lambda: AtenMatmulQMixedSigni8())
|
@register_test_case(module_factory=lambda: AtenMatmulQMixedSigni8())
|
||||||
|
@ -629,13 +647,15 @@ class AtenMatmulQMixedSigni8Transpose(torch.nn.Module):
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
def forward(self, x, y):
|
def forward(self, x, y):
|
||||||
qx = torch._make_per_tensor_quantized_tensor(x, 0.03, -66)
|
x = torch.torch.ops.quantized_decomposed.dequantize_per_tensor.default(
|
||||||
qx = torch.dequantize(qx)
|
x, 0.03, -66, -128, 127, torch.int8
|
||||||
qy = torch._make_per_tensor_quantized_tensor(y, 0.025, 160)
|
)
|
||||||
qy = torch.dequantize(qy)
|
y = torch.torch.ops.quantized_decomposed.dequantize_per_tensor.default(
|
||||||
qy = torch.transpose(qy, 1, 2)
|
y, 0.025, 160, 0, 255, torch.uint8
|
||||||
qz = torch.matmul(qx, qy)
|
)
|
||||||
return qz
|
y = torch.transpose(y, 1, 2)
|
||||||
|
z = torch.matmul(x, y)
|
||||||
|
return z
|
||||||
|
|
||||||
|
|
||||||
@register_test_case(module_factory=lambda: AtenMatmulQMixedSigni8Transpose())
|
@register_test_case(module_factory=lambda: AtenMatmulQMixedSigni8Transpose())
|
||||||
|
|
Loading…
Reference in New Issue