From a8ba865fcab6475ff58c2beb14a1823fc25314c2 Mon Sep 17 00:00:00 2001
From: zjgarvey <47986913+zjgarvey@users.noreply.github.com>
Date: Tue, 23 Apr 2024 13:01:36 -0500
Subject: [PATCH] [torch] Adds Quantization Support for `aten.relu` (#3177)

A choice was made to quantize the return type of Relu with a scale and
zero point copied from the input's quantization scheme. With this
choice, the torch-to-linalg conversion of quantized Relu essentially
computes max(input, zeroPoint) in the elementwise payload.
---
 .../TorchToLinalg/Uncategorized.cpp           | 78 +++++++++++++++---
 .../Torch/Transforms/FuseQuantizedOps.cpp     | 81 ++++++++++++++++++-
 projects/pt1/e2e_testing/xfail_sets.py        |  9 +++
 .../test_suite/elementwise.py                 | 63 +++++++++++++++
 4 files changed, 216 insertions(+), 15 deletions(-)

diff --git a/lib/Conversion/TorchToLinalg/Uncategorized.cpp b/lib/Conversion/TorchToLinalg/Uncategorized.cpp
index 441c76ce7..3c5d6cfae 100644
--- a/lib/Conversion/TorchToLinalg/Uncategorized.cpp
+++ b/lib/Conversion/TorchToLinalg/Uncategorized.cpp
@@ -56,6 +56,13 @@ static Value createComparisonTemplate(OpBuilder &b, Location loc, Type type,
   llvm_unreachable("Unhandled element type for comparison");
 }
 
+static Value getZeroPoint(Value value) {
+  if (auto make = value.getDefiningOp<Aten_MakePerTensorQuantizedTensorOp>()) {
+    return make.getZeroPoint();
+  }
+  return nullptr;
+}
+
 static Value createGreaterThan(OpBuilder &b, Location loc, Type elementalType,
                                Value lhs, Value rhs) {
   return createComparisonTemplate<arith::CmpFPredicate::UGT,
@@ -310,17 +317,66 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
   if (auto relu = dyn_cast<AtenReluOp>(op)) {
-    if (!relu.getType()
-             .cast<ValueTensorType>()
-             .getDtype()
-             .isa<mlir::FloatType>()) {
-      relu.emitError("unimplemented: non-floating point dtype");
+    Value zeroPoint = getZeroPoint(relu.getSelf());
+    Value arg = payloadArgs[0];
+    auto intType = arg.getType().dyn_cast<mlir::IntegerType>();
+    if (zeroPoint && !intType) {
+      relu.emitError("unimplemented: non-integer quantized Relu.");
       return nullptr;
     }
-    Type elementType = payloadArgs[0].getType();
-    Value constZero =
-        b.create<arith::ConstantOp>(loc, b.getZeroAttr(elementType));
-    Value pred = b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UGT,
-                                         payloadArgs[0], constZero);
-    return b.create<arith::SelectOp>(loc, pred, payloadArgs[0], constZero);
+    auto reluTorchType = cast<ValueTensorType>(relu.getType());
+    bool isUnsigned =
+        torch_to_linalg::isUnsignedTorchType(reluTorchType.getDtype());
+    if (zeroPoint) {
+      int64_t zeroPointInt;
+      int64_t width = intType.getWidth();
+      assert(width < 64);
+      int64_t minForIntType = isUnsigned ? 0 : -(1 << (width - 1));
+      int64_t maxForIntType =
+          isUnsigned ? (1 << (width + 1)) - 1 : (1 << (width - 1)) - 1;
+      // check for constant zero point edge-cases:
+      if (matchPattern(zeroPoint, m_TorchConstantInt(&zeroPointInt))) {
+        if (zeroPointInt > maxForIntType) {
+          // TODO: figure out how to handle this case:
+          // current impl. quantizes output like input.
+          // If zero point > maxForIntType, ordinary relu should return 0.
+          // However, 0 isn't represented in such a quantization scheme.
+          relu.emitError(
+              "unimplemented: quantized relu for zero-point > max qint");
+          return nullptr;
+        }
+        if (zeroPointInt < minForIntType)
+          return arg;
+      }
+      zeroPoint = converter->materializeTargetConversion(
+          b, loc, converter->convertType(zeroPoint.getType()), zeroPoint);
+      auto minForIntTypeValue = b.create<arith::ConstantOp>(
+          loc, b.getIntegerAttr(zeroPoint.getType(), minForIntType));
+      auto maxForIntTypeValue = b.create<arith::ConstantOp>(
+          loc, b.getIntegerAttr(zeroPoint.getType(), maxForIntType));
+      auto zpLtMax = b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt,
+                                             zeroPoint, maxForIntTypeValue);
+      b.create<cf::AssertOp>(
+          loc, zpLtMax,
+          b.getStringAttr("Invalid Quantization: quantized relu with "
+                          "zero-point > max qint"));
+      auto zpLtMin = b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt,
+                                             zeroPoint, minForIntTypeValue);
+      zeroPoint = b.create<arith::SelectOp>(loc, zpLtMin, minForIntTypeValue,
+                                            zeroPoint);
+      zeroPoint = b.create<arith::TruncIOp>(loc, arg.getType(), zeroPoint);
+    } else {
+      zeroPoint =
+          b.create<arith::ConstantOp>(loc, b.getZeroAttr(arg.getType()));
+    }
+    Value cmp;
+    if (intType) {
+      auto pred =
+          isUnsigned ? arith::CmpIPredicate::ugt : arith::CmpIPredicate::sgt;
+      cmp = b.create<arith::CmpIOp>(loc, pred, arg, zeroPoint);
+    } else {
+      cmp = b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UGT, arg,
+                                    zeroPoint);
+    }
+    return b.create<arith::SelectOp>(loc, cmp, arg, zeroPoint);
   }
   if (auto round = dyn_cast<AtenRoundOp>(op)) {
     if (!round.getType()
diff --git a/lib/Dialect/Torch/Transforms/FuseQuantizedOps.cpp b/lib/Dialect/Torch/Transforms/FuseQuantizedOps.cpp
index bff463c4c..3b30e9424 100644
--- a/lib/Dialect/Torch/Transforms/FuseQuantizedOps.cpp
+++ b/lib/Dialect/Torch/Transforms/FuseQuantizedOps.cpp
@@ -20,6 +20,13 @@ using namespace mlir::torch::Torch;
 
 namespace {
 
+template <typename SrcOp> struct QuantInfo {
+  static constexpr unsigned operandsToQuantize[2] = {0, 1};
+};
+
+template <> struct QuantInfo<AtenReluOp> {
+  static constexpr unsigned operandsToQuantize[1] = {0};
+};
 template <typename SrcOp>
 class QuantizeOperands : public OpRewritePattern<SrcOp> {
 public:
@@ -42,8 +49,9 @@ public:
       return operand;
     };
 
-    operands[0] = f(operands[0]);
-    operands[1] = f(operands[1]);
+    for (unsigned i : QuantInfo<SrcOp>::operandsToQuantize) {
+      operands[i] = f(operands[i]);
+    }
     if (!dequanted) {
       return rewriter.notifyMatchFailure(op, "no dequantizations found");
     }
@@ -259,6 +267,70 @@ public:
   }
 };
 
+// Use for ops which do not manipulate scale/zero point of an input.
+template <typename SrcOp>
+class QuantizeResultLikeOperand : public OpRewritePattern<SrcOp> {
+public:
+  using OpRewritePattern<SrcOp>::OpRewritePattern;
+  LogicalResult matchAndRewrite(SrcOp op,
+                                PatternRewriter &rewriter) const override {
+    llvm::SmallVector<Value> operands(op->getOperands());
+    Value input = operands[0];
+
+    auto inputType = dyn_cast_or_null<ValueTensorType>(input.getType());
+    if (!inputType || !inputType.hasDtype())
+      return failure();
+    auto qDtype = inputType.getDtype();
+
+    auto resultTy = dyn_cast_or_null<ValueTensorType>(op.getType());
+    if (!resultTy || !resultTy.hasDtype())
+      return failure();
+
+    Type resultETy = resultTy.getDtype();
+    if (!isa<mlir::FloatType>(resultETy))
+      return failure();
+
+    Value inputScale, inputZeroPoint;
+    Type definingOpInputType;
+    if (auto defining = input.template getDefiningOp<
+            Aten_MakePerTensorQuantizedTensorOp>()) {
+      inputScale = defining.getScale();
+      inputZeroPoint = defining.getZeroPoint();
+      definingOpInputType = defining.getSelf().getType();
+    }
+
+    auto inputIntReprType =
+        dyn_cast_or_null<ValueTensorType>(definingOpInputType);
+    if (!inputScale || !inputZeroPoint || !inputIntReprType ||
+        !inputIntReprType.hasDtype())
+      return failure();
+    auto intReprDtype = inputIntReprType.getDtype();
+
+    // set SrcOp type to use quantized dtype from input
+    auto newResultTy =
+        rewriter.getType<ValueTensorType>(resultTy.getOptionalSizes(), qDtype);
+    auto newResult = rewriter.create<SrcOp>(op.getLoc(), newResultTy, operands);
+
+    // int repr to get non quantized int type result
+    auto intReprTy = rewriter.getType<ValueTensorType>(
+        resultTy.getOptionalSizes(), intReprDtype);
+    auto intRepr =
+        rewriter.create<AtenIntReprOp>(op.getLoc(), intReprTy, newResult);
+
+    // requantize so the scale and zero-point info can be attached
+    auto quantTy =
+        rewriter.getType<ValueTensorType>(resultTy.getOptionalSizes(), qDtype);
+    auto quant = rewriter.create<Aten_MakePerTensorQuantizedTensorOp>(
+        op.getLoc(), quantTy, intRepr, inputScale, inputZeroPoint);
+
+    // dequant back to original dtype
+    auto dequant =
+        rewriter.create<AtenDequantizeTensorOp>(op.getLoc(), resultTy, quant);
+    rewriter.replaceOp(op, dequant);
+    return success();
+  }
+};
+
 template <typename SrcOp> class RemoveUnused : public OpRewritePattern<SrcOp> {
 public:
   using OpRewritePattern<SrcOp>::OpRewritePattern;
@@ -285,11 +357,12 @@ public:
         RemoveUnused<AtenDequantizeSelfOp>, RemoveUnused<AtenDequantizeTensorOp>,
         RemoveUnused<AtenQuantizePerTensorOp>,
         QuantizeOperands<AtenConvolutionOp>,
-        QuantizeOperands<AtenMatmulOp>,
+        QuantizeOperands<AtenMatmulOp>, QuantizeOperands<AtenReluOp>,
         QuantizeTransposedOperands<AtenMatmulOp>,
         QuantizeAccumulator<AtenMatmulOp>, QuantizeOperands<AtenMmOp>,
         QuantizeTransposedOperands<AtenMmOp>, QuantizeAccumulator<AtenMmOp>,
-        QuantizeBias<AtenConvolutionOp>>(context);
+        QuantizeResultLikeOperand<AtenReluOp>, QuantizeBias<AtenConvolutionOp>>(
+        context);
 
     GreedyRewriteConfig config;
     if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns),
diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py
index ccef6e106..93269b065 100644
--- a/projects/pt1/e2e_testing/xfail_sets.py
+++ b/projects/pt1/e2e_testing/xfail_sets.py
@@ -331,6 +331,9 @@ TORCHDYNAMO_XFAIL_SET = {
     "AtenMatmulQint8VV_basic",
     "AtenMatmulQint8VM_basic",
     "AtenMatmulQint8_basic",
+    "QuantizedReluInt32_basic",
+    "QuantizedReluInt8_basic",
+    "QuantizedReluUint8_basic",
     "Conv2dQInt8Module_basic",
 
     # Dynamo not supporting conv_tbc
@@ -413,6 +416,9 @@ FX_IMPORTER_XFAIL_SET = {
     'AtenMmQMixedSigni8_basic',
     'AtenMmQint8_basic',
     'AtenMmQuint8_basic',
+    "QuantizedReluInt32_basic",
+    "QuantizedReluInt8_basic",
+    "QuantizedReluUint8_basic",
     'AtenSubFloatModule_basic',
     'BincountMinlengthModule_basic',
     'BincountModule_basic',
@@ -2466,6 +2472,9 @@ ONNX_XFAIL_SET = {
     "PrimsSqueezeModule_basic",
     "PrimsViewOfModule_basic",
     "PrimsViewOfZeroRankModule_basic",
+    "QuantizedReluInt8_basic",
+    "QuantizedReluInt32_basic",
+    "QuantizedReluUint8_basic",
"RandIntDtypeModule_basic", "RandIntModule_basic", "RandIntPinMemoryModule_basic", diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py index 5010c8b99..b365ac54f 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py @@ -705,6 +705,69 @@ def ElementwiseReluModule_basic(module, tu: TestUtils): # ============================================================================== +class QuantizedReluInt8(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1], torch.int8, True), + ]) + def forward(self, x): + qx = torch._make_per_tensor_quantized_tensor(x, 0.0215, -25) + qx = torch.dequantize(qx) + return torch.relu(qx) + +@register_test_case(module_factory=lambda: QuantizedReluInt8()) +def QuantizedReluInt8_basic(module, tu: TestUtils): + module.forward(tu.randint(7, 4, low=-128, high=127).to(torch.int8)) + +# ============================================================================== + +class QuantizedReluUint8(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1], torch.uint8, True), + ]) + def forward(self, x): + qx = torch._make_per_tensor_quantized_tensor(x, 0.0215, 190) + qx = torch.dequantize(qx) + return torch.relu(qx) + +@register_test_case(module_factory=lambda: QuantizedReluUint8()) +def QuantizedReluUint8_basic(module, tu: TestUtils): + module.forward(tu.randint(7, 4, low=0, high=255).to(torch.uint8)) + +# ============================================================================== + +class QuantizedReluInt32(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1], torch.int32, True), + ]) + def forward(self, x): + qx = torch._make_per_tensor_quantized_tensor(x, 0.0215, 190) + qx = torch.dequantize(qx) + return torch.relu(qx) + +@register_test_case(module_factory=lambda: QuantizedReluInt32()) +def QuantizedReluInt32_basic(module, tu: TestUtils): + module.forward(tu.randint(7, 4, low=(-2**31), high=(2**31 - 1)).to(torch.int32)) + +# ============================================================================== + class ElementwiseRelu6Module(torch.nn.Module):