[torch] Adds Quantization Support for `aten.relu` (#3177)

A choice was made to quantize the return type of Relu with a scale and
zero point copied from the input's quantization scheme. With this
choice, the torch-to-linalg conversion of quantized Relu essentially
computes max(input, zeroPoint) in the elementwise payload.
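
The identity this relies on can be sanity-checked in a few lines of NumPy (an illustrative sketch, not part of the patch): for x = scale * (q - zp) with scale > 0, relu(x) = scale * (max(q, zp) - zp), so an output quantized with the input's scale and zero point is produced by max(q, zp) alone.

```python
import numpy as np

# Illustrative only: relu(scale * (q - zp)) == scale * (max(q, zp) - zp)
# for positive scale, which is what lets the lowering reuse the input's
# quantization scheme and compute just max(q, zp) in the payload.
scale, zp = 0.0215, -25
q = np.random.randint(-128, 128, size=(7, 4)).astype(np.int64)

float_relu = np.maximum(scale * (q - zp), 0.0)   # dequantize, then relu
int_payload = scale * (np.maximum(q, zp) - zp)   # what the lowering computes
assert np.allclose(float_relu, int_payload)
```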
zjgarvey 2024-04-23 13:01:36 -05:00 committed by GitHub
parent 09d42044b4
commit a8ba865fca
4 changed files with 216 additions and 15 deletions

lib/Conversion/TorchToLinalg/Uncategorized.cpp

@@ -56,6 +56,13 @@ static Value createComparisonTemplate(OpBuilder &b, Location loc, Type type,
   llvm_unreachable("Unhandled element type for comparison");
 }
 
+static Value getZeroPoint(Value value) {
+  if (auto make = value.getDefiningOp<Aten_MakePerTensorQuantizedTensorOp>()) {
+    return make.getZeroPoint();
+  }
+  return nullptr;
+}
+
 static Value createGreaterThan(OpBuilder &b, Location loc, Type elementalType,
                                Value lhs, Value rhs) {
   return createComparisonTemplate<arith::CmpFPredicate::OGT,
@@ -528,19 +535,68 @@
     return convertScalarToDtype(b, loc, div, outTy, std::nullopt, outTTy);
   }
   if (auto relu = dyn_cast<AtenReluOp>(op)) {
-    if (!relu.getType()
-             .cast<ValueTensorType>()
-             .getDtype()
-             .isa<mlir::FloatType>()) {
-      relu.emitError("unimplemented: non-floating point dtype");
+    Value zeroPoint = getZeroPoint(relu.getSelf());
+    Value arg = payloadArgs[0];
+    auto intType = arg.getType().dyn_cast<mlir::IntegerType>();
+    if (zeroPoint && !intType) {
+      relu.emitError("unimplemented: non-integer quantized Relu.");
       return nullptr;
     }
-    Type elementType = payloadArgs[0].getType();
-    Value constZero =
-        b.create<arith::ConstantOp>(loc, b.getZeroAttr(elementType));
-    Value pred = b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UGT,
-                                         payloadArgs[0], constZero);
-    return b.create<arith::SelectOp>(loc, pred, payloadArgs[0], constZero);
+    auto reluTorchType = cast<ValueTensorType>(relu.getType());
+    bool isUnsigned =
+        torch_to_linalg::isUnsignedTorchType(reluTorchType.getDtype());
+    if (zeroPoint) {
+      int64_t zeroPointInt;
+      int64_t width = intType.getWidth();
+      assert(width < 64);
+      int64_t minForIntType = isUnsigned ? 0 : -(1LL << (width - 1));
+      int64_t maxForIntType =
+          isUnsigned ? (1LL << width) - 1 : (1LL << (width - 1)) - 1;
+      // check for constant zero point edge-cases:
+      if (matchPattern(zeroPoint, m_TorchConstantInt(&zeroPointInt))) {
+        if (zeroPointInt > maxForIntType) {
+          // TODO: figure out how to handle this case:
+          // current impl. quantizes output like input.
+          // If zero point > maxForIntType, ordinary relu should return 0.
+          // However, 0 isn't represented in such a quantization scheme.
+          relu.emitError(
+              "unimplemented: quantized relu for zero-point > max qint");
+          return nullptr;
+        }
+        if (zeroPointInt < minForIntType)
+          return arg;
+      }
+      zeroPoint = converter->materializeTargetConversion(
+          b, loc, converter->convertType(zeroPoint.getType()), zeroPoint);
+      auto minForIntTypeValue = b.create<arith::ConstantOp>(
+          loc, b.getIntegerAttr(zeroPoint.getType(), minForIntType));
+      auto maxForIntTypeValue = b.create<arith::ConstantOp>(
+          loc, b.getIntegerAttr(zeroPoint.getType(), maxForIntType));
+      auto zpLtMax = b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt,
+                                             zeroPoint, maxForIntTypeValue);
+      b.create<cf::AssertOp>(
+          loc, zpLtMax,
+          b.getStringAttr("Invalid Quantization: quantized relu with "
+                          "zero-point > max qint"));
+      auto zpLtMin = b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt,
+                                             zeroPoint, minForIntTypeValue);
+      zeroPoint = b.create<arith::SelectOp>(loc, zpLtMin, minForIntTypeValue,
+                                            zeroPoint);
+      zeroPoint = b.create<arith::TruncIOp>(loc, arg.getType(), zeroPoint);
+    } else {
+      zeroPoint =
+          b.create<arith::ConstantOp>(loc, b.getZeroAttr(arg.getType()));
+    }
+    Value cmp;
+    if (intType) {
+      auto pred =
+          isUnsigned ? arith::CmpIPredicate::ugt : arith::CmpIPredicate::sgt;
+      cmp = b.create<arith::CmpIOp>(loc, pred, arg, zeroPoint);
+    } else {
+      cmp = b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UGT, arg,
+                                    zeroPoint);
+    }
+    return b.create<arith::SelectOp>(loc, cmp, arg, zeroPoint);
   }
   if (auto round = dyn_cast<AtenRoundOp>(op)) {
     if (!round.getType()
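
The zero-point edge cases above are easiest to see in scalar form. Below is a minimal sketch of the same decision logic (hypothetical Python for one `width`-bit element; not the lowering itself):

```python
def quantized_relu_payload(q: int, zp: int, width: int, is_unsigned: bool) -> int:
    """Scalar model of the quantized relu payload above."""
    qmin = 0 if is_unsigned else -(1 << (width - 1))
    qmax = (1 << width) - 1 if is_unsigned else (1 << (width - 1)) - 1
    if zp > qmax:
        # The output is quantized like the input, so the 0 that relu should
        # produce is not representable; the lowering reports an error for a
        # constant zp and emits a cf.assert for a dynamic one.
        raise NotImplementedError("quantized relu for zero-point > max qint")
    if zp < qmin:
        # Every representable q dequantizes to a positive value, so relu is
        # the identity on the integer representation.
        return q
    # Otherwise relu's clamp-at-zero becomes a clamp at the zero point.
    return max(q, zp)
```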

lib/Dialect/Torch/Transforms/FuseQuantizedOps.cpp

@@ -20,6 +20,13 @@ using namespace mlir::torch::Torch;
 
 namespace {
 
+template <typename SrcOp> struct QuantInfo {
+  static constexpr unsigned operandsToQuantize[2] = {0, 1};
+};
+
+template <> struct QuantInfo<AtenReluOp> {
+  static constexpr unsigned operandsToQuantize[1] = {0};
+};
 
 template <typename SrcOp>
 class QuantizeOperands : public OpRewritePattern<SrcOp> {
 public:
@@ -42,8 +49,9 @@ public:
       return operand;
     };
 
-    operands[0] = f(operands[0]);
-    operands[1] = f(operands[1]);
+    for (unsigned i : QuantInfo<SrcOp>::operandsToQuantize) {
+      operands[i] = f(operands[i]);
+    }
 
     if (!dequanted) {
       return rewriter.notifyMatchFailure(op, "no dequantizations found");
@@ -259,6 +267,70 @@
   }
 };
 
+// Use for ops which do not manipulate scale/zero point of an input.
+template <typename SrcOp>
+class QuantizeResultLikeOperand : public OpRewritePattern<SrcOp> {
+public:
+  using OpRewritePattern<SrcOp>::OpRewritePattern;
+  LogicalResult matchAndRewrite(SrcOp op,
+                                PatternRewriter &rewriter) const override {
+    llvm::SmallVector<Value> operands(op->getOperands());
+    Value input = operands[0];
+    auto inputType = dyn_cast_or_null<ValueTensorType>(input.getType());
+    if (!inputType || !inputType.hasDtype())
+      return failure();
+    auto qDtype = inputType.getDtype();
+    auto resultTy = dyn_cast_or_null<ValueTensorType>(op.getType());
+    if (!resultTy || !resultTy.hasDtype())
+      return failure();
+    Type resultETy = resultTy.getDtype();
+    if (!isa<mlir::FloatType>(resultETy))
+      return failure();
+
+    Value inputScale, inputZeroPoint;
+    Type definingOpInputType;
+    if (auto defining = input.template getDefiningOp<
+                        Aten_MakePerTensorQuantizedTensorOp>()) {
+      inputScale = defining.getScale();
+      inputZeroPoint = defining.getZeroPoint();
+      definingOpInputType = defining.getSelf().getType();
+    }
+    auto inputIntReprType =
+        dyn_cast_or_null<ValueTensorType>(definingOpInputType);
+    if (!inputScale || !inputZeroPoint || !inputIntReprType ||
+        !inputIntReprType.hasDtype())
+      return failure();
+    auto intReprDtype = inputIntReprType.getDtype();
+
+    // set SrcOp type to use quantized dtype from input
+    auto newResultTy =
+        rewriter.getType<ValueTensorType>(resultTy.getOptionalSizes(), qDtype);
+    auto newResult = rewriter.create<SrcOp>(op.getLoc(), newResultTy, operands);
+
+    // int repr to get non quantized int type result
+    auto intReprTy = rewriter.getType<ValueTensorType>(
+        resultTy.getOptionalSizes(), intReprDtype);
+    auto intRepr =
+        rewriter.create<AtenIntReprOp>(op.getLoc(), intReprTy, newResult);
+
+    // requantize so the scale and zero-point info can be attached
+    auto quantTy =
+        rewriter.getType<ValueTensorType>(resultTy.getOptionalSizes(), qDtype);
+    auto quant = rewriter.create<Aten_MakePerTensorQuantizedTensorOp>(
+        op.getLoc(), quantTy, intRepr, inputScale, inputZeroPoint);
+
+    // dequant back to original dtype
+    auto dequant =
+        rewriter.create<AtenDequantizeTensorOp>(op.getLoc(), resultTy, quant);
+    rewriter.replaceOp(op, dequant);
+    return success();
+  }
+};
+
 template <typename SrcOp> class RemoveUnused : public OpRewritePattern<SrcOp> {
 public:
   using OpRewritePattern<SrcOp>::OpRewritePattern;
@@ -285,11 +357,12 @@ public:
         RemoveUnused<AtenQuantizePerTensorOp>,
         RemoveUnused<Aten_MakePerTensorQuantizedTensorOp>,
         RemoveUnused<AtenTransposeIntOp>, QuantizeOperands<AtenConvolutionOp>,
-        QuantizeOperands<AtenMatmulOp>,
+        QuantizeOperands<AtenMatmulOp>, QuantizeOperands<AtenReluOp>,
         QuantizeTransposedOperands<AtenMatmulOp>,
         QuantizeAccumulator<AtenMatmulOp>, QuantizeOperands<AtenMmOp>,
         QuantizeTransposedOperands<AtenMmOp>, QuantizeAccumulator<AtenMmOp>,
-        QuantizeBias<AtenConvolutionOp>>(context);
+        QuantizeResultLikeOperand<AtenReluOp>, QuantizeBias<AtenConvolutionOp>>(
+        context);
 
     GreedyRewriteConfig config;
     if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns),
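
The combined effect of `QuantizeOperands<AtenReluOp>` and `QuantizeResultLikeOperand<AtenReluOp>` can be mimicked in eager PyTorch (a hypothetical illustration of the rewrite, not the pass itself): relu is moved onto the quantized tensor, and the result is re-tagged with the input's scale and zero point before dequantizing.

```python
import torch

# Hypothetical eager-mode analogue of the fused pattern.
x = torch.randint(-128, 128, (7, 4), dtype=torch.int8)
qx = torch._make_per_tensor_quantized_tensor(x, 0.0215, -25)

before = torch.relu(torch.dequantize(qx))  # the graph before fusion
after = torch.dequantize(
    torch._make_per_tensor_quantized_tensor(
        torch.relu(qx).int_repr(), 0.0215, -25))  # relu on the quantized tensor
assert torch.allclose(before, after)
```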

projects/pt1/e2e_testing/xfail_sets.py

@@ -331,6 +331,9 @@ TORCHDYNAMO_XFAIL_SET = {
     "AtenMatmulQint8VV_basic",
     "AtenMatmulQint8VM_basic",
     "AtenMatmulQint8_basic",
+    "QuantizedReluInt32_basic",
+    "QuantizedReluInt8_basic",
+    "QuantizedReluUint8_basic",
     "Conv2dQInt8Module_basic",
     # Dynamo not supporting conv_tbc
@@ -413,6 +416,9 @@ FX_IMPORTER_XFAIL_SET = {
     'AtenMmQMixedSigni8_basic',
     'AtenMmQint8_basic',
     'AtenMmQuint8_basic',
+    "QuantizedReluInt32_basic",
+    "QuantizedReluInt8_basic",
+    "QuantizedReluUint8_basic",
     'AtenSubFloatModule_basic',
     'BincountMinlengthModule_basic',
     'BincountModule_basic',
@@ -2466,6 +2472,9 @@ ONNX_XFAIL_SET = {
     "PrimsSqueezeModule_basic",
     "PrimsViewOfModule_basic",
     "PrimsViewOfZeroRankModule_basic",
+    "QuantizedReluInt8_basic",
+    "QuantizedReluInt32_basic",
+    "QuantizedReluUint8_basic",
     "RandIntDtypeModule_basic",
     "RandIntModule_basic",
     "RandIntPinMemoryModule_basic",

projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py

@@ -705,6 +705,69 @@ def ElementwiseReluModule_basic(module, tu: TestUtils):
 
 # ==============================================================================
 
+
+class QuantizedReluInt8(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1], torch.int8, True),
+    ])
+    def forward(self, x):
+        qx = torch._make_per_tensor_quantized_tensor(x, 0.0215, -25)
+        qx = torch.dequantize(qx)
+        return torch.relu(qx)
+
+
+@register_test_case(module_factory=lambda: QuantizedReluInt8())
+def QuantizedReluInt8_basic(module, tu: TestUtils):
+    module.forward(tu.randint(7, 4, low=-128, high=127).to(torch.int8))
+
+
+# ==============================================================================
+
+
+class QuantizedReluUint8(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1], torch.uint8, True),
+    ])
+    def forward(self, x):
+        qx = torch._make_per_tensor_quantized_tensor(x, 0.0215, 190)
+        qx = torch.dequantize(qx)
+        return torch.relu(qx)
+
+
+@register_test_case(module_factory=lambda: QuantizedReluUint8())
+def QuantizedReluUint8_basic(module, tu: TestUtils):
+    module.forward(tu.randint(7, 4, low=0, high=255).to(torch.uint8))
+
+
+# ==============================================================================
+
+
+class QuantizedReluInt32(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1], torch.int32, True),
+    ])
+    def forward(self, x):
+        qx = torch._make_per_tensor_quantized_tensor(x, 0.0215, 190)
+        qx = torch.dequantize(qx)
+        return torch.relu(qx)
+
+
+@register_test_case(module_factory=lambda: QuantizedReluInt32())
+def QuantizedReluInt32_basic(module, tu: TestUtils):
+    module.forward(tu.randint(7, 4, low=(-2**31), high=(2**31 - 1)).to(torch.int32))
+
+
 # ==============================================================================
 
 class ElementwiseRelu6Module(torch.nn.Module):