mirror of https://github.com/llvm/torch-mlir
[torch] Fix clamp ranges on quantize_per_tensor on unsigned (#3018)
SExtValue was used for `int` and `uint` clamp values. This caused the result to always be output as `zero`.
pull/3047/head
parent
cb5cb506df
commit
3a56714bff
|
@ -1504,10 +1504,14 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
|
|||
APInt max = isUnsigned ? APInt::getMaxValue(bitwidth)
|
||||
: APInt::getSignedMaxValue(bitwidth);
|
||||
|
||||
Value minVal = b.create<arith::ConstantOp>(
|
||||
loc, b.getFloatAttr(valueTy, min.getSExtValue()));
|
||||
Value maxVal = b.create<arith::ConstantOp>(
|
||||
loc, b.getFloatAttr(valueTy, max.getSExtValue()));
|
||||
double minI = isUnsigned ? static_cast<double>(min.getZExtValue())
|
||||
: static_cast<double>(min.getSExtValue());
|
||||
double maxI = isUnsigned ? static_cast<double>(max.getZExtValue())
|
||||
: static_cast<double>(max.getSExtValue());
|
||||
Value minVal =
|
||||
b.create<arith::ConstantOp>(loc, b.getFloatAttr(valueTy, minI));
|
||||
Value maxVal =
|
||||
b.create<arith::ConstantOp>(loc, b.getFloatAttr(valueTy, maxI));
|
||||
Value minCmp =
|
||||
b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::ULT, value, minVal);
|
||||
Value maxCmp =
|
||||
|
|
|
@ -50,7 +50,7 @@ public:
|
|||
op.getLoc(), operandTy.getElementType(), operand, indices);
|
||||
auto extractTy = extract.getType();
|
||||
if (isa<mlir::IntegerType>(extractTy) && !extractTy.isInteger(64)) {
|
||||
if (torchDTy.isSignlessInteger()) {
|
||||
if (torchDTy.isUnsignedInteger()) {
|
||||
extract = rewriter.create<arith::ExtUIOp>(
|
||||
op.getLoc(), rewriter.getIntegerType(64), extract);
|
||||
} else {
|
||||
|
|
|
@ -296,6 +296,7 @@ TORCHDYNAMO_XFAIL_SET = {
|
|||
"ElementwiseDequantizePerChannelModule_basic",
|
||||
"ElementwiseDequantizePerTensorModule_basic",
|
||||
"ElementwiseQuantizePerTensorModule_basic",
|
||||
"ElementwiseQuantizePerTensorUIntModule_basic",
|
||||
"AtenMmQuint8_basic",
|
||||
"Conv2dQInt8Module_basic",
|
||||
|
||||
|
@ -1619,6 +1620,7 @@ ONNX_XFAIL_SET = {
|
|||
"ElementwiseOrTensorModule_basic",
|
||||
"ElementwiseOrTensorStaticShapeModule_basic",
|
||||
"ElementwiseQuantizePerTensorModule_basic",
|
||||
"ElementwiseQuantizePerTensorUIntModule_basic",
|
||||
"ElementwiseRemainderTensorModule_Int_basic",
|
||||
"ElementwiseFmodTensor_Int_basic",
|
||||
"EmptyStridedModule_basic",
|
||||
|
|
|
@ -4675,6 +4675,31 @@ class ElementwiseQuantizePerTensorModule(torch.nn.Module):
|
|||
def ElementwiseQuantizePerTensorModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3, 4))
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
class ElementwiseQuantizePerTensorUIntModule(torch.nn.Module):
    """Quantize a float tensor to quint8 and return its integer representation.

    Exercises the unsigned (`torch.quint8`) path of `quantize_per_tensor`,
    which previously produced all-zero results when signed extension was
    used for the clamp bounds.
    """

    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1, -1], torch.float, True),
    ])
    def forward(self, x):
        scale = 0.04
        zp = 11
        dtype = torch.quint8
        # We return the int representation as we can not map to quint8 type yet on boundaries.
        q = torch.quantize_per_tensor(x, scale, zp, dtype).int_repr()
        q = q.to(torch.int8)
        return q


@register_test_case(module_factory=lambda: ElementwiseQuantizePerTensorUIntModule())
def ElementwiseQuantizePerTensorUIntModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(3, 4))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
|
Loading…
Reference in New Issue