mirror of https://github.com/llvm/torch-mlir
[torch] Adds Quantization Support for `aten.relu` (#3177)
The return type of Relu is quantized with a scale and zero point copied from the input's quantization scheme. With this choice, the torch-to-linalg conversion of quantized Relu essentially computes max(input, zeroPoint) in the elementwise payload.

parent 09d42044b4
commit a8ba865fca
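For intuition, a minimal sketch in plain PyTorch (not part of this commit; the scale and zero point are arbitrary example values) showing why reusing the input's quantization scheme turns relu into an integer max:

import torch

# If relu's output reuses the input's (scale, zero_point), then
#   relu(dequantize(q)) == dequantize(max(q, zero_point))
# because dequantize(q) = (q - zero_point) * scale and scale > 0.
scale, zero_point = 0.0215, -25
x = torch.randint(-128, 128, (7, 4), dtype=torch.int8)

expected = torch.relu(torch.dequantize(
    torch._make_per_tensor_quantized_tensor(x, scale, zero_point)))

raw = torch.clamp(x.to(torch.int64), min=zero_point)       # max(input, zeroPoint)
recomputed = (raw - zero_point).to(torch.float32) * scale  # dequantize by hand
assert torch.allclose(expected, recomputed)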
@@ -56,6 +56,13 @@ static Value createComparisonTemplate(OpBuilder &b, Location loc, Type type,
   llvm_unreachable("Unhandled element type for comparison");
 }
 
+static Value getZeroPoint(Value value) {
+  if (auto make = value.getDefiningOp<Aten_MakePerTensorQuantizedTensorOp>()) {
+    return make.getZeroPoint();
+  }
+  return nullptr;
+}
+
 static Value createGreaterThan(OpBuilder &b, Location loc, Type elementalType,
                                Value lhs, Value rhs) {
   return createComparisonTemplate<arith::CmpFPredicate::OGT,
@@ -528,19 +535,68 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
     return convertScalarToDtype(b, loc, div, outTy, std::nullopt, outTTy);
   }
   if (auto relu = dyn_cast<AtenReluOp>(op)) {
-    if (!relu.getType()
-             .cast<ValueTensorType>()
-             .getDtype()
-             .isa<mlir::FloatType>()) {
-      relu.emitError("unimplemented: non-floating point dtype");
+    Value zeroPoint = getZeroPoint(relu.getSelf());
+    Value arg = payloadArgs[0];
+    auto intType = arg.getType().dyn_cast<mlir::IntegerType>();
+    if (zeroPoint && !intType) {
+      relu.emitError("unimplemented: non-integer quantized Relu.");
       return nullptr;
     }
-    Type elementType = payloadArgs[0].getType();
-    Value constZero =
-        b.create<arith::ConstantOp>(loc, b.getZeroAttr(elementType));
-    Value pred = b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UGT,
-                                         payloadArgs[0], constZero);
-    return b.create<arith::SelectOp>(loc, pred, payloadArgs[0], constZero);
+    auto reluTorchType = cast<ValueTensorType>(relu.getType());
+    bool isUnsigned =
+        torch_to_linalg::isUnsignedTorchType(reluTorchType.getDtype());
+    if (zeroPoint) {
+      int64_t zeroPointInt;
+      int64_t width = intType.getWidth();
+      assert(width < 64);
+      // Use 64-bit shifts so width == 32 does not overflow.
+      int64_t minForIntType = isUnsigned ? 0 : -(1LL << (width - 1));
+      int64_t maxForIntType =
+          isUnsigned ? (1LL << width) - 1 : (1LL << (width - 1)) - 1;
+      // check for constant zero point edge-cases:
+      if (matchPattern(zeroPoint, m_TorchConstantInt(&zeroPointInt))) {
+        if (zeroPointInt > maxForIntType) {
+          // TODO: figure out how to handle this case:
+          // current impl. quantizes output like input.
+          // If zero point > maxForIntType, ordinary relu should return 0.
+          // However, 0 isn't represented in such a quantization scheme.
+          relu.emitError(
+              "unimplemented: quantized relu for zero-point > max qint");
+          return nullptr;
+        }
+        if (zeroPointInt < minForIntType)
+          return arg;
+      }
+      zeroPoint = converter->materializeTargetConversion(
+          b, loc, converter->convertType(zeroPoint.getType()), zeroPoint);
+      auto minForIntTypeValue = b.create<arith::ConstantOp>(
+          loc, b.getIntegerAttr(zeroPoint.getType(), minForIntType));
+      auto maxForIntTypeValue = b.create<arith::ConstantOp>(
+          loc, b.getIntegerAttr(zeroPoint.getType(), maxForIntType));
+      auto zpLtMax = b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt,
+                                             zeroPoint, maxForIntTypeValue);
+      b.create<cf::AssertOp>(
+          loc, zpLtMax,
+          b.getStringAttr("Invalid Quantization: quantized relu with "
+                          "zero-point > max qint"));
+      auto zpLtMin = b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt,
+                                             zeroPoint, minForIntTypeValue);
+      zeroPoint = b.create<arith::SelectOp>(loc, zpLtMin, minForIntTypeValue,
+                                            zeroPoint);
+      zeroPoint = b.create<arith::TruncIOp>(loc, arg.getType(), zeroPoint);
+    } else {
+      zeroPoint =
+          b.create<arith::ConstantOp>(loc, b.getZeroAttr(arg.getType()));
+    }
+    Value cmp;
+    if (intType) {
+      auto pred =
+          isUnsigned ? arith::CmpIPredicate::ugt : arith::CmpIPredicate::sgt;
+      cmp = b.create<arith::CmpIOp>(loc, pred, arg, zeroPoint);
+    } else {
+      cmp = b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UGT, arg,
+                                    zeroPoint);
+    }
+    return b.create<arith::SelectOp>(loc, cmp, arg, zeroPoint);
   }
   if (auto round = dyn_cast<AtenRoundOp>(op)) {
     if (!round.getType()
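For reference, a small Python sketch (illustrative, not from the patch) of the integer ranges and zero-point edge cases the payload above guards against:

def qint_range(width, is_unsigned):
    # Mirrors minForIntType/maxForIntType in the payload above.
    if is_unsigned:
        return 0, (1 << width) - 1
    return -(1 << (width - 1)), (1 << (width - 1)) - 1

assert qint_range(8, False) == (-128, 127)  # int8
assert qint_range(8, True) == (0, 255)      # uint8

# Zero point below the type's minimum: every stored value is already
# >= zero point, so max(input, zeroPoint) is the identity and the
# payload simply returns `arg`.
# Zero point above the type's maximum: relu must produce the real value 0,
# but 0 is not representable when the output reuses the input's quantization
# scheme, so the lowering emits an error (or a runtime cf.assert).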
@@ -20,6 +20,13 @@ using namespace mlir::torch::Torch;
 
 namespace {
 
+template <typename SrcOp> struct QuantInfo {
+  static constexpr unsigned operandsToQuantize[2] = {0, 1};
+};
+
+template <> struct QuantInfo<AtenReluOp> {
+  static constexpr unsigned operandsToQuantize[1] = {0};
+};
 template <typename SrcOp>
 class QuantizeOperands : public OpRewritePattern<SrcOp> {
 public:
@@ -42,8 +49,9 @@ public:
       return operand;
     };
 
-    operands[0] = f(operands[0]);
-    operands[1] = f(operands[1]);
+    for (unsigned i : QuantInfo<SrcOp>::operandsToQuantize) {
+      operands[i] = f(operands[i]);
+    }
 
     if (!dequanted) {
       return rewriter.notifyMatchFailure(op, "no dequantizations found");
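The QuantInfo trait replaces the hard-coded operands[0]/operands[1] above: each op declares which operand indices should be traced back through dequantize ops. A rough Python analogue of that dispatch (the table and function are illustrative, not torch-mlir API):

# Default: quantize both operands; per-op overrides mirror the
# QuantInfo template specializations above.
OPERANDS_TO_QUANTIZE = {
    "aten.convolution": (0, 1),
    "aten.matmul": (0, 1),
    "aten.relu": (0,),  # relu has a single tensor operand
}

def quantize_operands(op_name, operands, f):
    for i in OPERANDS_TO_QUANTIZE.get(op_name, (0, 1)):
        operands[i] = f(operands[i])
    return operands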
@@ -259,6 +267,70 @@ public:
   }
 };
 
+// Use for ops which do not manipulate scale/zero point of an input.
+template <typename SrcOp>
+class QuantizeResultLikeOperand : public OpRewritePattern<SrcOp> {
+public:
+  using OpRewritePattern<SrcOp>::OpRewritePattern;
+  LogicalResult matchAndRewrite(SrcOp op,
+                                PatternRewriter &rewriter) const override {
+    llvm::SmallVector<Value> operands(op->getOperands());
+    Value input = operands[0];
+
+    auto inputType = dyn_cast_or_null<ValueTensorType>(input.getType());
+    if (!inputType || !inputType.hasDtype())
+      return failure();
+    auto qDtype = inputType.getDtype();
+
+    auto resultTy = dyn_cast_or_null<ValueTensorType>(op.getType());
+    if (!resultTy || !resultTy.hasDtype())
+      return failure();
+
+    Type resultETy = resultTy.getDtype();
+    if (!isa<mlir::FloatType>(resultETy))
+      return failure();
+
+    Value inputScale, inputZeroPoint;
+    Type definingOpInputType;
+    if (auto defining = input.template getDefiningOp<
+                        Aten_MakePerTensorQuantizedTensorOp>()) {
+      inputScale = defining.getScale();
+      inputZeroPoint = defining.getZeroPoint();
+      definingOpInputType = defining.getSelf().getType();
+    }
+
+    auto inputIntReprType =
+        dyn_cast_or_null<ValueTensorType>(definingOpInputType);
+    if (!inputScale || !inputZeroPoint || !inputIntReprType ||
+        !inputIntReprType.hasDtype())
+      return failure();
+    auto intReprDtype = inputIntReprType.getDtype();
+
+    // set SrcOp type to use quantized dtype from input
+    auto newResultTy =
+        rewriter.getType<ValueTensorType>(resultTy.getOptionalSizes(), qDtype);
+    auto newResult = rewriter.create<SrcOp>(op.getLoc(), newResultTy, operands);
+
+    // int repr to get non quantized int type result
+    auto intReprTy = rewriter.getType<ValueTensorType>(
+        resultTy.getOptionalSizes(), intReprDtype);
+    auto intRepr =
+        rewriter.create<AtenIntReprOp>(op.getLoc(), intReprTy, newResult);
+
+    // requantize so the scale and zero-point info can be attached
+    auto quantTy =
+        rewriter.getType<ValueTensorType>(resultTy.getOptionalSizes(), qDtype);
+    auto quant = rewriter.create<Aten_MakePerTensorQuantizedTensorOp>(
+        op.getLoc(), quantTy, intRepr, inputScale, inputZeroPoint);
+
+    // dequant back to original dtype
+    auto dequant =
+        rewriter.create<AtenDequantizeTensorOp>(op.getLoc(), resultTy, quant);
+    rewriter.replaceOp(op, dequant);
+    return success();
+  }
+};
+
 template <typename SrcOp> class RemoveUnused : public OpRewritePattern<SrcOp> {
 public:
   using OpRewritePattern<SrcOp>::OpRewritePattern;
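In eager PyTorch terms, the QuantizeResultLikeOperand round trip corresponds roughly to the following sketch (using eager quantized relu as a stand-in for the rewritten SrcOp; the scale and zero point are arbitrary):

import torch

scale, zero_point = 0.05, 10
x = torch.randint(0, 256, (3, 3), dtype=torch.uint8)
qx = torch._make_per_tensor_quantized_tensor(x, scale, zero_point)

qy = torch.relu(qx)          # the rewritten SrcOp, run on the quantized dtype
y_int = qy.int_repr()        # AtenIntReprOp: strip to the raw integers
qy2 = torch._make_per_tensor_quantized_tensor(  # re-attach input's scale/zp
    y_int, scale, zero_point)
y = torch.dequantize(qy2)    # AtenDequantizeTensorOp: back to float

assert torch.equal(y, torch.dequantize(qy))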
@@ -285,11 +357,12 @@ public:
         RemoveUnused<AtenQuantizePerTensorOp>,
         RemoveUnused<Aten_MakePerTensorQuantizedTensorOp>,
         RemoveUnused<AtenTransposeIntOp>, QuantizeOperands<AtenConvolutionOp>,
-        QuantizeOperands<AtenMatmulOp>,
+        QuantizeOperands<AtenMatmulOp>, QuantizeOperands<AtenReluOp>,
         QuantizeTransposedOperands<AtenMatmulOp>,
         QuantizeAccumulator<AtenMatmulOp>, QuantizeOperands<AtenMmOp>,
         QuantizeTransposedOperands<AtenMmOp>, QuantizeAccumulator<AtenMmOp>,
-        QuantizeBias<AtenConvolutionOp>>(context);
+        QuantizeResultLikeOperand<AtenReluOp>, QuantizeBias<AtenConvolutionOp>>(
+        context);
 
     GreedyRewriteConfig config;
     if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns),
@@ -331,6 +331,9 @@ TORCHDYNAMO_XFAIL_SET = {
     "AtenMatmulQint8VV_basic",
     "AtenMatmulQint8VM_basic",
     "AtenMatmulQint8_basic",
+    "QuantizedReluInt32_basic",
+    "QuantizedReluInt8_basic",
+    "QuantizedReluUint8_basic",
     "Conv2dQInt8Module_basic",
 
     # Dynamo not supporting conv_tbc
@@ -413,6 +416,9 @@ FX_IMPORTER_XFAIL_SET = {
     'AtenMmQMixedSigni8_basic',
     'AtenMmQint8_basic',
     'AtenMmQuint8_basic',
+    "QuantizedReluInt32_basic",
+    "QuantizedReluInt8_basic",
+    "QuantizedReluUint8_basic",
     'AtenSubFloatModule_basic',
     'BincountMinlengthModule_basic',
     'BincountModule_basic',
@@ -2466,6 +2472,9 @@ ONNX_XFAIL_SET = {
     "PrimsSqueezeModule_basic",
     "PrimsViewOfModule_basic",
     "PrimsViewOfZeroRankModule_basic",
+    "QuantizedReluInt8_basic",
+    "QuantizedReluInt32_basic",
+    "QuantizedReluUint8_basic",
     "RandIntDtypeModule_basic",
     "RandIntModule_basic",
     "RandIntPinMemoryModule_basic",
@@ -705,6 +705,69 @@ def ElementwiseReluModule_basic(module, tu: TestUtils):
 
 # ==============================================================================
 
+class QuantizedReluInt8(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1], torch.int8, True),
+    ])
+    def forward(self, x):
+        qx = torch._make_per_tensor_quantized_tensor(x, 0.0215, -25)
+        qx = torch.dequantize(qx)
+        return torch.relu(qx)
+
+@register_test_case(module_factory=lambda: QuantizedReluInt8())
+def QuantizedReluInt8_basic(module, tu: TestUtils):
+    module.forward(tu.randint(7, 4, low=-128, high=127).to(torch.int8))
+
+# ==============================================================================
+
+class QuantizedReluUint8(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1], torch.uint8, True),
+    ])
+    def forward(self, x):
+        qx = torch._make_per_tensor_quantized_tensor(x, 0.0215, 190)
+        qx = torch.dequantize(qx)
+        return torch.relu(qx)
+
+@register_test_case(module_factory=lambda: QuantizedReluUint8())
+def QuantizedReluUint8_basic(module, tu: TestUtils):
+    module.forward(tu.randint(7, 4, low=0, high=255).to(torch.uint8))
+
+# ==============================================================================
+
+class QuantizedReluInt32(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1], torch.int32, True),
+    ])
+    def forward(self, x):
+        qx = torch._make_per_tensor_quantized_tensor(x, 0.0215, 190)
+        qx = torch.dequantize(qx)
+        return torch.relu(qx)
+
+@register_test_case(module_factory=lambda: QuantizedReluInt32())
+def QuantizedReluInt32_basic(module, tu: TestUtils):
+    module.forward(tu.randint(7, 4, low=(-2**31), high=(2**31 - 1)).to(torch.int32))
+
+# ==============================================================================
+
 
 class ElementwiseRelu6Module(torch.nn.Module):