mirror of https://github.com/llvm/torch-mlir
[fx] Fix importing and tests for quantized conv (#3809)
The fx tracer currently does not support tracing "real" quantized tensors. A "real" quantized tensor here means a tensor that is created using a method like `torch.quantize_per_tensor()` and carries the quantization parameters (scale, zero_point, scheme) in the object. However, it seems like the DQ-Q type fake quantization is now commonly used as a high-level representation of quantized operators and is only lowered to native quantized ops (if available) in the respective hardware backend. In PyTorch, quantization of floating-point modules is nowadays also performed as a graph transformation after exporting/tracing the original module.

```python
# Examples of "real"/native quantization
tens = torch.randint(-127, 127, (1,), dtype=torch.int8)
torch._make_per_tensor_quantized_tensor(tens, 1, 0)
# tensor([90.], size=(1,), dtype=torch.qint8,
#        quantization_scheme=torch.per_tensor_affine, scale=1.0, zero_point=0)

tens = torch.rand((1,))
torch.quantize_per_tensor(tens, 1, 0, torch.qint8)
# tensor([1.], size=(1,), dtype=torch.qint8,
#        quantization_scheme=torch.per_tensor_affine, scale=1.0, zero_point=0)

# Example of DQ/Q quantization
import torch.ao.quantization.fx._decomposed
tens = torch.rand((1,))
torch.ops.quantized_decomposed.quantize_per_tensor.default(tens, 1, 0, -128, 127, torch.int8)
# tensor([1], dtype=torch.int8)
```

This means that a typical import flow for a quantized network into/through torch-mlir would look like this:

`torch.export() -> quantization transformations on fx graph -> fx_importer`

where the tensors in the graph are normal float/int tensors and the quantization parameters are carried by the DQ/Q ops. These kinds of graphs can be traced without issues.

Currently, our quantized convolution tests use the "real" quantized tensors, which means that with the retirement of the `jit_ir_importer` these tests can no longer be imported. In summary, I see no reason to stick to "real" quantization in these tests: PyTorch 2.0 uses DQ/Q quantization, and our linalg backend consumes it as well. This patch therefore updates the quantized convolution tests to use DQ-Q quantization with the ops from `torch.ops.quantized_decomposed`.

Note: For future reference, there seems to be an ongoing consolidation of the ops for the DQ/Q scheme on the PyTorch side (https://github.com/pytorch/ao/issues/986#issuecomment-2390296826).
parent 140cad5659
commit 42ba541c68
@@ -420,15 +420,7 @@ FX_IMPORTER_XFAIL_SET = {
     "CeilFloatModule_basic",
     "ContainsIntList_False",
     "ContainsIntList_True",
-    "Conv2dQInt8Module_basic",
-    "Conv2dQInt8Module_depthwise",
-    "Conv2dQInt8Module_grouped",
-    "Conv2dQInt8Module_not_depthwise",
-    "Conv2dQInt8PerChannelModule_basic",
-    "Conv2dQInt8PerChannelModule_depthwise",
-    "Conv2dQInt8PerChannelModule_grouped",
     "ConvTbcModule_basic",
-    "ConvTranspose2DQInt8_basic",
     "ConvolutionBackwardModule2DPadded_basic",
     "ConvolutionBackwardModule2DStrided_basic",
     "ConvolutionBackwardModule2D_basic",
@@ -1183,23 +1183,28 @@ def ConvTbcModule_basic(module, tu: TestUtils):
     module.forward(tu.rand(9, 4, 5), tu.rand(3, 5, 6), tu.rand(6))


+# For DQ-Q fake quantization ops
+import torch.ao.quantization.fx._decomposed
+
+
 class Conv2dQInt8ModuleBase(torch.nn.Module):
     def __init__(self, groups=1):
         self.groups = groups
         super().__init__()

-    def _forward(self, inputVec, weight, bias):
-        inputVec = torch._make_per_tensor_quantized_tensor(inputVec, 0.01, 7)
-        inputVec = torch.dequantize(inputVec)
+    def _forward(self, input, weight, bias):
+        input = torch.ops.quantized_decomposed.dequantize_per_tensor.default(
+            input, 0.01, 7, -128, 127, torch.int8
+        )
+        weight = torch.ops.quantized_decomposed.dequantize_per_tensor.default(
+            weight, 0.01, 3, -128, 127, torch.int8
+        )
+        bias = torch.ops.quantized_decomposed.dequantize_per_tensor.default(
+            bias, 1, 0, -1000, 1000, torch.int32
+        )

-        weight = torch._make_per_tensor_quantized_tensor(weight, 0.01, 3)
-        weight = torch.dequantize(weight)
-
-        bias = torch.quantize_per_tensor(bias, 0.0001, 0, torch.qint32)
-        bias = torch.dequantize(bias)
-
-        return torch.ops.aten.conv2d(
-            inputVec,
+        conv = torch.ops.aten.conv2d(
+            input,
             weight,
             bias=bias,
             stride=[1, 1],
@@ -1208,6 +1213,11 @@ class Conv2dQInt8ModuleBase(torch.nn.Module):
             groups=self.groups,
         )

+        # Use int32 to avoid overflows
+        return torch.ops.quantized_decomposed.quantize_per_tensor.default(
+            conv, 1, 0, -(2**31), 2**31 - 1, torch.int32
+        )
+

 class Conv2dQInt8ModuleDyn(Conv2dQInt8ModuleBase):
     @export
@@ -1216,7 +1226,7 @@ class Conv2dQInt8ModuleDyn(Conv2dQInt8ModuleBase):
             None,
             ([-1, -1, -1, -1], torch.int8, True),
             ([-1, -1, -1, -1], torch.int8, True),
-            ([-1], torch.float, True),
+            ([-1], torch.int32, True),
         ]
     )
     def forward(self, inputVec, weight, bias):
@@ -1230,7 +1240,7 @@ class Conv2dQInt8ModuleStatic(Conv2dQInt8ModuleBase):
             None,
             ([2, 3, 12, 12], torch.int8, True),
             ([3, 1, 5, 3], torch.int8, True),
-            ([3], torch.float, True),
+            ([3], torch.int32, True),
         ]
     )
     def forward(self, inputVec, weight, bias):
@@ -1244,7 +1254,7 @@ class Conv2dQInt8ModuleStatic_MoreOutChannels(Conv2dQInt8ModuleBase):
             None,
             ([2, 3, 12, 12], torch.int8, True),
             ([6, 1, 5, 3], torch.int8, True),
-            ([6], torch.float, True),
+            ([6], torch.int32, True),
         ]
     )
     def forward(self, inputVec, weight, bias):
@@ -1255,7 +1265,7 @@ class Conv2dQInt8ModuleStatic_MoreOutChannels(Conv2dQInt8ModuleBase):
 def Conv2dQInt8Module_basic(module, tu: TestUtils):
     inputVec = tu.randint(2, 4, 7, 8, low=-128, high=127).to(torch.int8)
     weight = tu.randint(3, 4, 3, 2, low=-128, high=127).to(torch.int8)
-    bias = torch.rand(3)
+    bias = tu.randint(3, low=-1000, high=1000).to(torch.int32)
     module.forward(inputVec, weight, bias)


@@ -1263,7 +1273,7 @@ def Conv2dQInt8Module_basic(module, tu: TestUtils):
 def Conv2dQInt8Module_grouped(module, tu: TestUtils):
     inputVec = tu.randint(2, 8, 7, 8, low=-128, high=127).to(torch.int8)
     weight = tu.randint(6, 4, 3, 2, low=-128, high=127).to(torch.int8)
-    bias = torch.rand(6)
+    bias = tu.randint(6, low=-1000, high=1000).to(torch.int32)
     module.forward(inputVec, weight, bias)


@@ -1271,7 +1281,7 @@ def Conv2dQInt8Module_grouped(module, tu: TestUtils):
 def Conv2dQInt8Module_depthwise(module, tu: TestUtils):
     inputVec = tu.randint(2, 3, 12, 12, low=-128, high=127).to(torch.int8)
     weight = tu.randint(3, 1, 5, 3, low=-128, high=127).to(torch.int8)
-    bias = torch.rand(3)
+    bias = tu.randint(3, low=-1000, high=1000).to(torch.int32)
     module.forward(inputVec, weight, bias)


@@ -1281,7 +1291,7 @@ def Conv2dQInt8Module_depthwise(module, tu: TestUtils):
 def Conv2dQInt8Module_not_depthwise(module, tu: TestUtils):
     inputVec = tu.randint(2, 3, 12, 12, low=-128, high=127).to(torch.int8)
     weight = tu.randint(6, 1, 5, 3, low=-128, high=127).to(torch.int8)
-    bias = torch.rand(6)
+    bias = tu.randint(6, low=-1000, high=1000).to(torch.int32)
     module.forward(inputVec, weight, bias)


@@ -1300,16 +1310,17 @@ class ConvTranspose2DQInt8Module(torch.nn.Module):
         ]
     )
     def forward(self, input, weight, bias):
-        qinput = torch._make_per_tensor_quantized_tensor(input, 0.01, -25)
-        qinput = torch.dequantize(qinput)
-        qweight = torch._make_per_tensor_quantized_tensor(weight, 0.01, 50)
-        qweight = torch.dequantize(qweight)
-        qbias = torch.quantize_per_tensor(bias, 0.0001, 0, torch.qint32)
-        qbias = torch.dequantize(qbias)
-        qz = torch.ops.aten.convolution(
-            qinput,
-            qweight,
-            bias=qbias,
+        input = torch.ops.quantized_decomposed.dequantize_per_tensor.default(
+            input, 0.01, -25, -128, 127, torch.int8
+        )
+        weight = torch.ops.quantized_decomposed.dequantize_per_tensor.default(
+            weight, 0.01, 50, -128, 127, torch.int8
+        )
+
+        res = torch.ops.aten.convolution(
+            input,
+            weight,
+            bias=bias,
             stride=[2, 1],
             padding=[1, 1],
             dilation=[1, 1],
@@ -1317,7 +1328,11 @@ class ConvTranspose2DQInt8Module(torch.nn.Module):
             output_padding=[0, 0],
             groups=1,
         )
-        return qz
+
+        # Use int32 to avoid overflows
+        return torch.ops.quantized_decomposed.quantize_per_tensor.default(
+            res, 1, 0, -(2**31), 2**31 - 1, torch.int32
+        )


 @register_test_case(module_factory=lambda: ConvTranspose2DQInt8Module())
@@ -1342,18 +1357,14 @@ class Conv2dQInt8PerChannelModuleBase(torch.nn.Module):
         super().__init__()

     def _forward(self, inputVec, weight, scales, zeropoints, bias):
-        inputVec = torch._make_per_tensor_quantized_tensor(inputVec, 0.01, 7)
-        inputVec = torch.dequantize(inputVec)
-
-        weight = torch._make_per_channel_quantized_tensor(
-            weight, scales, zeropoints, axis=0
+        inputVec = torch.ops.quantized_decomposed.dequantize_per_tensor.default(
+            inputVec, 0.01, 7, -128, 127, torch.int8
+        )
+        weight = torch.ops.quantized_decomposed.dequantize_per_channel.default(
+            weight, scales, zeropoints, 0, -128, 127, torch.int8
         )
-        weight = torch.dequantize(weight)

-        bias = torch.quantize_per_tensor(bias, 0.0001, 0, torch.qint32)
-        bias = torch.dequantize(bias)
-
-        return torch.ops.aten.conv2d(
+        conv = torch.ops.aten.conv2d(
             inputVec,
             weight,
             bias=bias,
@@ -1363,6 +1374,11 @@ class Conv2dQInt8PerChannelModuleBase(torch.nn.Module):
             groups=self.groups,
         )

+        # Use int32 to avoid overflows
+        return torch.ops.quantized_decomposed.quantize_per_tensor.default(
+            conv, 1, 0, -(2**31), 2**31 - 1, torch.int32
+        )
+

 class Conv2dQInt8PerChannelModuleDyn(Conv2dQInt8PerChannelModuleBase):
     @export
@@ -41,7 +41,7 @@ def _module_lowering(
     option_string = "{extra-library=" + extra_library_file_name + "}"
     run_pipeline_with_repro_report(
         torch_mod,
-        f"builtin.module(torchdynamo-export-to-torch-backend-pipeline{option_string})",
+        f"builtin.module(func.func(torch-match-quantized-custom-ops), torchdynamo-export-to-torch-backend-pipeline{option_string})",
         "Lowering TorchFX IR -> Torch Backend IR",
         enable_ir_printing=verbose,
     )