mirror of https://github.com/llvm/torch-mlir
Add Brevitas quantization for per tensor and per channel modes
parent c29c07b29b
commit 23ced3dca6

@@ -10635,6 +10635,84 @@ def Torch_AtenLeakyReluBackwardOp : Torch_Op<"aten.leaky_relu_backward", [
  }];
}

def Torch_AtenQuantizePerTensorOp : Torch_Op<"aten.quantize_per_tensor", [
    AllowsTypeRefinement,
    HasValueSemantics,
    ReadOnly
  ]> {
  let summary = "Generated op for `aten::quantize_per_tensor : (Tensor, float, int, int) -> (Tensor)`";
  let arguments = (ins
    AnyTorchTensorType:$self,
    Torch_FloatType:$scale,
    Torch_IntType:$zero_point,
    Torch_IntType:$dtype
  );
  let results = (outs
    AnyTorchTensorType:$result
  );
  let hasCustomAssemblyFormat = 1;
  let extraClassDefinition = [{
    ParseResult AtenQuantizePerTensorOp::parse(OpAsmParser &parser, OperationState &result) {
      return parseDefaultTorchOp(parser, result, 4, 1);
    }
    void AtenQuantizePerTensorOp::print(OpAsmPrinter &printer) {
      printDefaultTorchOp(printer, *this, 4, 1);
    }
  }];
  let hasCanonicalizer = 1;
}

def Torch_AtenQuantizePerChannelOp : Torch_Op<"aten.quantize_per_channel", [
    AllowsTypeRefinement,
    HasValueSemantics,
    ReadOnly
  ]> {
  let summary = "Generated op for `aten::quantize_per_channel : (Tensor, Tensor, Tensor, int, int) -> (Tensor)`";
  let arguments = (ins
    AnyTorchTensorType:$self,
    AnyTorchTensorType:$scales,
    AnyTorchTensorType:$zero_points,
    Torch_IntType:$axis,
    Torch_IntType:$dtype
  );
  let results = (outs
    AnyTorchTensorType:$result
  );
  let hasCustomAssemblyFormat = 1;
  let extraClassDefinition = [{
    ParseResult AtenQuantizePerChannelOp::parse(OpAsmParser &parser, OperationState &result) {
      return parseDefaultTorchOp(parser, result, 5, 1);
    }
    void AtenQuantizePerChannelOp::print(OpAsmPrinter &printer) {
      printDefaultTorchOp(printer, *this, 5, 1);
    }
  }];
  let hasCanonicalizer = 1;
}

def Torch_AtenIntReprOp : Torch_Op<"aten.int_repr", [
    AllowsTypeRefinement,
    HasValueSemantics,
    ReadOnly
  ]> {
  let summary = "Generated op for `aten::int_repr : (Tensor) -> (Tensor)`";
  let arguments = (ins
    AnyTorchTensorType:$self
  );
  let results = (outs
    AnyTorchTensorType:$result
  );
  let hasCustomAssemblyFormat = 1;
  let extraClassDefinition = [{
    ParseResult AtenIntReprOp::parse(OpAsmParser &parser, OperationState &result) {
      return parseDefaultTorchOp(parser, result, 1, 1);
    }
    void AtenIntReprOp::print(OpAsmPrinter &printer) {
      printDefaultTorchOp(printer, *this, 1, 1);
    }
  }];
}

def Torch_PrimLayoutOp : Torch_Op<"prim.layout", [
    AllowsTypeRefinement,
    HasValueSemantics,
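For orientation, the three generated op definitions above correspond one-to-one to PyTorch's quantization primitives; a minimal PyTorch sketch of the calls they are registered for (illustrative shapes and values, not part of this patch):

import torch

x = torch.randn(2, 3, 4, 4)

# aten::quantize_per_tensor : (Tensor, float, int, int) -> (Tensor)
q_t = torch.quantize_per_tensor(x, 0.1, 0, torch.qint8)

# aten::quantize_per_channel : (Tensor, Tensor, Tensor, int, int) -> (Tensor)
scales = torch.full((3,), 0.1)
zero_points = torch.zeros(3, dtype=torch.int64)
q_c = torch.quantize_per_channel(x, scales, zero_points, 1, torch.qint8)

# aten::int_repr : (Tensor) -> (Tensor) exposes the underlying integer storage
print(q_t.int_repr().dtype)  # torch.int8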
@@ -15,6 +15,7 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Matchers.h"
#include "torch-mlir/Conversion/Utils/Utils.h"
@@ -474,6 +475,109 @@ public:
};
} // namespace

namespace {
class ConvertAtenQuantizePerChannelOp
    : public OpConversionPattern<AtenQuantizePerChannelOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(AtenQuantizePerChannelOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {

    Location loc = op->getLoc();
    MLIRContext *context = op.getContext();
    Value self = adaptor.getSelf();
    auto selfRank =
        adaptor.getSelf().getType().cast<RankedTensorType>().getRank();
    Type inputDtype =
        adaptor.getSelf().getType().cast<RankedTensorType>().getElementType();
    if (!inputDtype.isa<mlir::FloatType>()) {
      return rewriter.notifyMatchFailure(
          op, "input tensor must be floating point dtype for quantization");
    }

    Type newResultType = getTypeConverter()->convertType(op.getType());
    Type outputDtype = newResultType.cast<RankedTensorType>().getElementType();
    if (!outputDtype.isa<mlir::IntegerType>()) {
      return rewriter.notifyMatchFailure(
          op, "target quantization type must be an integer type");
    }

    int64_t axis;
    if (!matchPattern(op.getAxis(), m_TorchConstantInt(&axis)))
      return rewriter.notifyMatchFailure(op, "only constant axis supported");

    // Zero-initialized output tensor; every element is overwritten by the
    // quantized values produced in the linalg.generic below.
    Value initTensor = createZeroInitTensor(
        rewriter, loc, getTensorSizes(rewriter, loc, self), outputDtype);

    SmallVector<utils::IteratorType> iteratorTypes(
        selfRank, utils::IteratorType::parallel);

    // Scales and zero points are indexed only along the quantization axis.
    AffineMap quantAxisMap =
        AffineMap::get(selfRank, 0, rewriter.getAffineDimExpr(axis), context);
    AffineMap identity = AffineMap::getMultiDimIdentityMap(selfRank, context);

    SmallVector<AffineMap> indexingMaps{identity, quantAxisMap, quantAxisMap,
                                        identity};

    Value quantized =
        rewriter
            .create<linalg::GenericOp>(
                loc, newResultType,
                ValueRange{self, adaptor.getScales(), adaptor.getZeroPoints()},
                initTensor, indexingMaps, iteratorTypes,
                [&](OpBuilder &b, Location loc, ValueRange args) {
                  Value lhs = convertScalarToDtype(b, loc, args[0], inputDtype);
                  Value scale =
                      convertScalarToDtype(b, loc, args[1], inputDtype);
                  Value scaled = b.create<arith::DivFOp>(loc, lhs, scale);
                  Value rounded = b.create<math::RoundOp>(loc, scaled);
                  Type intermediateDtype = b.getIntegerType(
                      cast<mlir::FloatType>(inputDtype).getWidth());
                  Value fpToI =
                      convertScalarToDtype(b, loc, rounded, intermediateDtype);

                  Value zeroPoint =
                      convertScalarToDtype(b, loc, args[2], intermediateDtype);
                  Value shifted =
                      b.create<arith::AddIOp>(loc, fpToI, zeroPoint);

                  Value quantMin = b.create<arith::ConstantOp>(
                      loc,
                      b.getIntegerAttr(
                          intermediateDtype,
                          llvm::minIntN(cast<mlir::IntegerType>(outputDtype)
                                            .getWidth())));
                  Value quantMax = b.create<arith::ConstantOp>(
                      loc,
                      b.getIntegerAttr(
                          intermediateDtype,
                          llvm::maxIntN(cast<mlir::IntegerType>(outputDtype)
                                            .getWidth())));
                  Value minCompare = b.create<arith::CmpIOp>(
                      loc, arith::CmpIPredicate::slt, shifted, quantMin);
                  Value minClamp = b.create<arith::SelectOp>(loc, minCompare,
                                                             quantMin, shifted);

                  Value maxCompare = b.create<arith::CmpIOp>(
                      loc, arith::CmpIPredicate::sgt, minClamp, quantMax);
                  Value maxClamp = b.create<arith::SelectOp>(
                      loc, maxCompare, quantMax, minClamp);

                  Value truncated =
                      convertScalarToDtype(b, loc, maxClamp, outputDtype);
                  b.create<linalg::YieldOp>(loc, truncated);
                })
            .getResult(0);

    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, quantized);
    return success();
  }
};
} // namespace

namespace {
class ConvertAtenConvolutionOp : public OpConversionPattern<AtenConvolutionOp> {
public:
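The linalg.generic body in the pattern above is the usual affine quantization formula, and the per-tensor elementwise payload added further down performs the same computation scalar by scalar. A NumPy reference sketch of that formula (hypothetical helper, for illustration only; the int8 bounds correspond to llvm::minIntN/maxIntN for an 8-bit output type):

import numpy as np

def quantize_reference(x, scale, zero_point, qmin=-128, qmax=127):
    # Divide by scale, round, add the zero point, clamp to the target
    # integer range, then truncate to the output dtype.
    q = np.round(x / scale) + zero_point
    q = np.clip(q, qmin, qmax)
    return q.astype(np.int8)

# Per-channel mode: scale and zero_point broadcast along the quantization
# axis, mirroring the (identity, axis, axis, identity) indexing maps.
x = np.random.randn(3, 4).astype(np.float32)
scales = np.array([0.1, 0.2, 0.3], dtype=np.float32).reshape(3, 1)
zero_points = np.zeros((3, 1), dtype=np.int64)
print(quantize_reference(x, scales, zero_points))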
@@ -493,8 +597,8 @@ public:
    Type elementType =
        input.getType().cast<RankedTensorType>().getElementType();
    if (!elementType.isa<mlir::FloatType>())
      return op.emitError("unimplemented: non-floating point type");
    // if (!elementType.isa<mlir::FloatType>())
    //   return op.emitError("unimplemented: non-floating point type");
    size_t inRank = input.getType().cast<RankedTensorType>().getRank();
    size_t numSpacialDims = inRank - 2;
    if (numSpacialDims != 2)
@@ -661,21 +765,25 @@ public:
              castIndexToInt(weightDims[i]), strideIntValues[i]));
    }

    Type outputElementType = elementType;
    if (outputElementType.isInteger(8))
      outputElementType = rewriter.getIntegerType(32);

    Value initTensor = rewriter.create<tensor::EmptyOp>(
        loc, getAsOpFoldResult(outDims), elementType);
        loc, getAsOpFoldResult(outDims), outputElementType);

    Value bias = adaptor.getBias();
    Value outputTensor;
    if (bias.getType().isa<Torch::NoneType>()) {
      Value c0float = rewriter.create<arith::ConstantOp>(
          loc, FloatAttr::get(elementType, 0.0));
      outputTensor = rewriter.create<linalg::FillOp>(loc, c0float, initTensor)
                         .getResult(0);
      Value c0Val = rewriter.create<arith::ConstantOp>(
          loc, rewriter.getZeroAttr(outputElementType));
      outputTensor =
          rewriter.create<linalg::FillOp>(loc, c0Val, initTensor).getResult(0);
    } else {
      auto biasType = bias.getType().cast<RankedTensorType>();
      if (biasType.getRank() != 1)
        return rewriter.notifyMatchFailure(op, "expect bias to be rank 1");
      if (elementType != biasType.getElementType())
      if (outputElementType != biasType.getElementType())
        return rewriter.notifyMatchFailure(op, "unimplemented: type promotion");

      auto resultRank = initTensor.getType().cast<RankedTensorType>().getRank();
@@ -840,6 +948,8 @@ void mlir::torch::torch_to_linalg::populateLinearPatternsAndLegality(
  patterns.add<ConvertAtenMmOp>(typeConverter, context);
  target.addIllegalOp<AtenFlipOp>();
  patterns.add<ConvertAtenFlipOp>(typeConverter, context);
  target.addIllegalOp<AtenQuantizePerChannelOp>();
  patterns.add<ConvertAtenQuantizePerChannelOp>(typeConverter, context);
  target.addIllegalOp<AtenMatmulOp>();
  patterns.add<ConvertAtenMatmulOp>(typeConverter, context);
  target.addIllegalOp<AtenBmmOp>();
@@ -45,8 +45,7 @@ static Value createComparisonTemplate(OpBuilder &b, Location loc, Type type,
  if (IntegerType intType = type.dyn_cast<mlir::IntegerType>()) {
    if (intType.isUnsigned())
      return b.create<arith::CmpIOp>(loc, iupred, lhs, rhs);
    if (intType.isSigned())
      return b.create<arith::CmpIOp>(loc, ispred, lhs, rhs);
    return b.create<arith::CmpIOp>(loc, ispred, lhs, rhs);
  }
  llvm_unreachable("Unhandled element type for comparison");
}
@@ -798,10 +797,6 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
    Type dtype = converter->convertType(clamp.getType())
                     .cast<RankedTensorType>()
                     .getElementType();
    if (!dtype.isa<mlir::FloatType>()) {
      clamp.emitError("unimplemented: non-floating point dtype");
      return nullptr;
    }
    AtenClampOp::Adaptor adaptor(operands);
    auto min = adaptor.getMin();
    auto max = adaptor.getMax();
@@ -813,14 +808,18 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
    auto result = payloadArgs[0];
    if (!min.getType().isa<Torch::NoneType>()) {
      auto minPromoted = convertScalarToDtype(b, loc, min, dtype);
      auto pred = b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::ULT,
                                          result, minPromoted);
      auto pred = createComparisonTemplate<arith::CmpFPredicate::ULT,
                                           arith::CmpIPredicate::ult,
                                           arith::CmpIPredicate::slt>(
          b, loc, dtype, result, minPromoted);
      result = b.create<arith::SelectOp>(loc, pred, minPromoted, result);
    }
    if (!max.getType().isa<Torch::NoneType>()) {
      auto maxPromoted = convertScalarToDtype(b, loc, max, dtype);
      auto pred = b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UGT,
                                          result, maxPromoted);
      auto pred = createComparisonTemplate<arith::CmpFPredicate::UGT,
                                           arith::CmpIPredicate::ugt,
                                           arith::CmpIPredicate::sgt>(
          b, loc, dtype, result, maxPromoted);
      result = b.create<arith::SelectOp>(loc, pred, maxPromoted, result);
    }
    return result;
@@ -995,6 +994,65 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
    return convertScalarToDtype(b, loc, payloadArgs[1], dtype);
  }

  if (auto quantizePerTensor = dyn_cast<AtenQuantizePerTensorOp>(op)) {
    Type inputDtype =
        converter->convertType(quantizePerTensor.getSelf().getType())
            .cast<RankedTensorType>()
            .getElementType();
    if (!inputDtype.isa<mlir::FloatType>()) {
      quantizePerTensor.emitError(
          "input tensor must be floating point dtype for quantization");
      return nullptr;
    }

    Type outputDtype =
        converter->convertType(quantizePerTensor.getResult().getType())
            .cast<RankedTensorType>()
            .getElementType();
    if (!outputDtype.isa<mlir::IntegerType>()) {
      quantizePerTensor.emitError(
          "target quantization type must be an integer type");
      return nullptr;
    }
    Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], inputDtype);
    Value scale = convertScalarToDtype(b, loc, operands[1], inputDtype);
    Value scaled = b.create<arith::DivFOp>(loc, lhs, scale);
    Value rounded = b.create<math::RoundOp>(loc, scaled);
    Type intermediateDtype =
        b.getIntegerType(cast<mlir::FloatType>(inputDtype).getWidth());
    Value fpToI = convertScalarToDtype(b, loc, rounded, intermediateDtype);

    Value zeroPoint =
        convertScalarToDtype(b, loc, operands[2], intermediateDtype);
    Value shifted = b.create<arith::AddIOp>(loc, fpToI, zeroPoint);

    Value quantMin = b.create<arith::ConstantOp>(
        loc,
        b.getIntegerAttr(
            intermediateDtype,
            llvm::minIntN(cast<mlir::IntegerType>(outputDtype).getWidth())));
    Value quantMax = b.create<arith::ConstantOp>(
        loc,
        b.getIntegerAttr(
            intermediateDtype,
            llvm::maxIntN(cast<mlir::IntegerType>(outputDtype).getWidth())));
    Value minCompare = createComparisonTemplate<arith::CmpFPredicate::ULT,
                                                arith::CmpIPredicate::ult,
                                                arith::CmpIPredicate::slt>(
        b, loc, intermediateDtype, shifted, quantMin);
    Value minClamp =
        b.create<arith::SelectOp>(loc, minCompare, quantMin, shifted);

    Value maxCompare = createComparisonTemplate<arith::CmpFPredicate::UGT,
                                                arith::CmpIPredicate::ugt,
                                                arith::CmpIPredicate::sgt>(
        b, loc, intermediateDtype, minClamp, quantMax);
    Value maxClamp =
        b.create<arith::SelectOp>(loc, maxCompare, quantMax, minClamp);

    return convertScalarToDtype(b, loc, maxClamp, outputDtype);
  }

  if (auto triu = dyn_cast<AtenTriuOp>(op)) {
    // Check if the rank of the input tensor is valid.
    AtenTriuOp::Adaptor adaptor(operands);
@@ -1088,7 +1146,8 @@ public:
             AtenNeScalarOp, AtenNegOp, AtenMaskedFillScalarOp,
             AtenMaskedFillTensorOp, AtenLogicalOrOp, AtenLogicalAndOp,
             AtenLogicalXorOp, AtenLogicalNotOp, AtenTriuOp, AtenBitwiseNotOp,
             AtenRoundOp, AtenFillScalarOp, AtenFillTensorOp>(op))
             AtenRoundOp, AtenFillScalarOp, AtenFillTensorOp,
             AtenQuantizePerTensorOp>(op))
      return rewriter.notifyMatchFailure(op, "not a supported elementwise op");

    if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
@@ -1578,4 +1637,7 @@ void mlir::torch::torch_to_linalg::populateUncategorizedPatternsAndLegality(
  patterns.add<ConvertAtenNllLossBackwardOp>(typeConverter, context);
  patterns.add<ConvertTensorStaticInfoCastOp>(typeConverter, context);
  target.addIllegalOp<TensorStaticInfoCastOp>();

  // Quantization
  target.addIllegalOp<AtenQuantizePerTensorOp>();
}
@@ -94,8 +94,7 @@ Value torch_to_linalg::getDynamicZeroPaddedTensor(
          SmallVector<int64_t>(inRank, kUnknownSize))),
      elementType);

  Value cf0 =
      b.create<arith::ConstantOp>(loc, b.getFloatAttr(elementType, 0.0));
  Value cf0 = b.create<arith::ConstantOp>(loc, b.getZeroAttr(elementType));
  SmallVector<OpFoldResult> paddingValues =
      getAsOpFoldResult(paddingIncludingUnchanged);
  return b.create<tensor::PadOp>(loc, inputType, input, /*low=*/paddingValues,
@@ -263,7 +263,7 @@ Value convertScalarToDtype(OpBuilder &b, Location loc, Value scalar, Type dtype,

  // We only support conversion from Byte or Char scalarType not to Byte or Char
  // dtype.
  if (isByteOrChar(dtype)) {
  if (isByteOrChar(dtype) && !scalarType.isa<mlir::IntegerType>()) {
    mlir::emitError(loc) << "unsupported: conversion to byte or char type for "
                            "convertScalarToDtype "
                         << scalarType << "(scalar type) -> " << dtype
@@ -29,6 +29,17 @@ using namespace mlir::torch::Torch;
// Utilities
//===----------------------------------------------------------------------===//

static inline torch_upstream::ScalarType
materializeQType(torch_upstream::ScalarType t) {
  if (t == torch_upstream::ScalarType::QInt8)
    return torch_upstream::ScalarType::Char;
  if (t == torch_upstream::ScalarType::QUInt8)
    return torch_upstream::ScalarType::Char;
  if (t == torch_upstream::ScalarType::QInt32)
    return torch_upstream::ScalarType::Char;
  return t;
}

Value mlir::torch::Torch::adjustStaticInformation(OpBuilder &builder,
                                                  Location loc, Value value,
                                                  Type desiredType,
@@ -1393,6 +1404,167 @@ void AtenSortIntOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
  });
}

//===----------------------------------------------------------------------===//
// AtenQuantizePerTensorOp
//===----------------------------------------------------------------------===//

static void transposeOutputChannels(PatternRewriter &rewriter, Location loc,
                                    Operation *target, Value other) {
  Value zero =
      rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(0));
  Value one =
      rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(1));
  AtenTransposeIntOp transposed = rewriter.create<AtenTransposeIntOp>(
      loc,
      Torch::NonValueTensorType::getWithLeastStaticInformation(
          target->getContext()),
      other, zero, one);
  transposed->moveBefore(target);
  target->replaceUsesOfWith(other, transposed.getResult());
}

template <typename QuantizationOp>
static LogicalResult commuteQuantizedConvolution(QuantizationOp op,
                                                 PatternRewriter &rewriter) {
  auto result = op.getResult();
  if (!result.hasOneUse())
    return rewriter.notifyMatchFailure(op, "quantize op has multiple uses");

  if (auto clampOp = dyn_cast<AtenClampOp>(*result.getUsers().begin())) {
    result = clampOp.getResult();
  }

  AtenSubTensorOp subZeroPoint;
  if (!(subZeroPoint = dyn_cast<AtenSubTensorOp>(*result.getUsers().begin()))) {
    return rewriter.notifyMatchFailure(
        op, "quantize op does not have sub tensor as user");
  }
  auto subResult = subZeroPoint.getResult();

  AtenMulTensorOp mulScale;
  if (!(mulScale = dyn_cast<AtenMulTensorOp>(*subResult.getUsers().begin()))) {
    return rewriter.notifyMatchFailure(
        op, "quantize op does not have mul tensor in chain");
  }
  auto mulResult = mulScale.getResult();

  Aten_ConvolutionOp conv;
  if (!(conv = dyn_cast<Aten_ConvolutionOp>(*mulResult.getUsers().begin()))) {
    return rewriter.notifyMatchFailure(
        op, "quantize op does not have convolution in chain");
  }

  auto convResult = conv.getResult();
  conv->replaceUsesOfWith(mulResult, result);
  convResult.replaceAllUsesWith(mulResult);
  subZeroPoint->replaceUsesOfWith(result, convResult);
  subZeroPoint->moveAfter(conv);
  mulScale->moveAfter(subZeroPoint);

  if (isa<AtenQuantizePerChannelOp>(op)) {
    Value other;
    if (subZeroPoint.getSelf() == convResult) {
      other = subZeroPoint.getOther();
    } else {
      other = subZeroPoint.getSelf();
    }
    Location otherLoc = conv->getLoc();
    transposeOutputChannels(rewriter, otherLoc, subZeroPoint, other);

    if (mulScale.getSelf() == subResult) {
      other = mulScale.getOther();
    } else {
      other = mulScale.getSelf();
    }
    transposeOutputChannels(rewriter, otherLoc, mulScale, other);
  }
  return success();
}

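For context, the chain that commuteQuantizedConvolution walks (quantize -> int_repr -> optional clamp -> sub zero_point -> mul scale -> _convolution) is roughly how a Brevitas-style export expresses a dequantized weight feeding a convolution. A hypothetical PyTorch rendering of that producer pattern (shapes and values illustrative, not taken from the patch):

import torch
import torch.nn.functional as F

x = torch.randn(1, 3, 8, 8)
w = torch.randn(4, 3, 3, 3)
scales = torch.full((4,), 0.05)
zero_points = torch.zeros(4, dtype=torch.int64)

# quantize_per_channel -> int_repr -> sub zero_point -> mul scale -> convolution
w_q = torch.quantize_per_channel(w, scales, zero_points, 0, torch.qint8)
w_dq = (w_q.int_repr().float() - zero_points.view(-1, 1, 1, 1)) * scales.view(-1, 1, 1, 1)
y = F.conv2d(x, w_dq)

The rewrite above moves the convolution up to consume the quantized tensor directly and pushes the zero-point subtraction and scale multiplication after it, transposing the per-channel tensors onto the convolution's output-channel dimension in the per-channel case.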
void AtenQuantizePerTensorOp::getCanonicalizationPatterns(
    RewritePatternSet &patterns, MLIRContext *context) {
  patterns.add(+[](AtenQuantizePerTensorOp op, PatternRewriter &rewriter) {
    auto loc = op.getLoc();
    auto result = op.getResult();
    if (!result.hasOneUse())
      return rewriter.notifyMatchFailure(
          op, "quantize per tensor op has multiple uses");

    AtenIntReprOp reprOp;
    if (!(reprOp = dyn_cast<AtenIntReprOp>(*result.getUsers().begin())))
      return rewriter.notifyMatchFailure(
          op, "quantize per tensor op use must be aten.int_repr");

    int64_t dtypeInt;
    if (!matchPattern(op.getDtype(), m_TorchConstantInt(&dtypeInt))) {
      return failure();
    }

    auto scalarDtype = materializeQType((torch_upstream::ScalarType)dtypeInt);
    auto dtypeValue = rewriter.create<Torch::ConstantIntOp>(
        loc, rewriter.getI64IntegerAttr((int64_t)scalarDtype));

    auto resultType = result.getType();
    auto quantized = rewriter
                         .create<AtenQuantizePerTensorOp>(
                             loc, resultType, op.getSelf(), op.getScale(),
                             op.getZeroPoint(), dtypeValue)
                         .getResult();

    reprOp.getResult().replaceAllUsesWith(quantized);
    rewriter.eraseOp(reprOp);
    rewriter.replaceOp(op, quantized);
    return success();
  });
  patterns.add(+[](AtenQuantizePerTensorOp op, PatternRewriter &rewriter) {
    return commuteQuantizedConvolution<AtenQuantizePerTensorOp>(op, rewriter);
  });
}

//===----------------------------------------------------------------------===//
// AtenQuantizePerChannelOp
//===----------------------------------------------------------------------===//

void AtenQuantizePerChannelOp::getCanonicalizationPatterns(
    RewritePatternSet &patterns, MLIRContext *context) {
  patterns.add(+[](AtenQuantizePerChannelOp op, PatternRewriter &rewriter) {
    auto loc = op.getLoc();
    auto result = op.getResult();
    if (!result.hasOneUse())
      return rewriter.notifyMatchFailure(
          op, "quantize per channel op has multiple uses");

    AtenIntReprOp reprOp;
    if (!(reprOp = dyn_cast<AtenIntReprOp>(*result.getUsers().begin())))
      return rewriter.notifyMatchFailure(
          op, "quantize per channel op use must be aten.int_repr");

    int64_t dtypeInt;
    if (!matchPattern(op.getDtype(), m_TorchConstantInt(&dtypeInt))) {
      return failure();
    }

    auto scalarDtype = materializeQType((torch_upstream::ScalarType)dtypeInt);
    auto dtypeValue = rewriter.create<Torch::ConstantIntOp>(
        loc, rewriter.getI64IntegerAttr((int64_t)scalarDtype));

    auto resultType = result.getType();
    auto quantized = rewriter
                         .create<AtenQuantizePerChannelOp>(
                             loc, resultType, op.getSelf(), op.getScales(),
                             op.getZeroPoints(), op.getAxis(), dtypeValue)
                         .getResult();

    reprOp.getResult().replaceAllUsesWith(quantized);
    rewriter.eraseOp(reprOp);
    rewriter.replaceOp(op, quantized);
    return success();
  });
  patterns.add(+[](AtenQuantizePerChannelOp op, PatternRewriter &rewriter) {
    return commuteQuantizedConvolution<AtenQuantizePerChannelOp>(op, rewriter);
  });
}

//===----------------------------------------------------------------------===//
// NonValueTensorLiteralOp
//===----------------------------------------------------------------------===//
@@ -7446,6 +7446,18 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
"    }\n"
"    return %3 : !torch.int\n"
"  }\n"
"  func.func @\"__torch_mlir_shape_fn.aten.quantize_per_tensor\"(%arg0: !torch.list<int>, %arg1: !torch.float, %arg2: !torch.int, %arg3: !torch.int) -> !torch.list<int> {\n"
"    return %arg0 : !torch.list<int>\n"
"  }\n"
"  func.func @\"__torch_mlir_dtype_fn.aten.quantize_per_tensor\"(%arg0: !torch.int, %arg1: !torch.int, %arg2: !torch.float, %arg3: !torch.int, %arg4: !torch.int) -> !torch.int {\n"
"    return %arg4 : !torch.int\n"
"  }\n"
"  func.func @\"__torch_mlir_shape_fn.aten.quantize_per_channel\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>, %arg3: !torch.int, %arg4: !torch.int) -> !torch.list<int> {\n"
"    return %arg0 : !torch.list<int>\n"
"  }\n"
"  func.func @\"__torch_mlir_dtype_fn.aten.quantize_per_channel\"(%arg0: !torch.int, %arg1: !torch.int, %arg2: !torch.int, %arg3: !torch.int, %arg4: !torch.int, %arg5: !torch.int, %arg6: !torch.int, %arg7: !torch.int) -> !torch.int {\n"
"    return %arg7 : !torch.int\n"
"  }\n"
"  func.func @\"__torch_mlir_shape_fn.aten.bincount\"(%arg0: !torch.list<int>, %arg1: !torch.optional<list<int>>, %arg2: !torch.int) -> !torch.list<int> {\n"
"    %0 = call @__torch__.hacky_get_unknown_dimension_size() : () -> !torch.int\n"
"    %1 = torch.prim.ListConstruct %0 : (!torch.int) -> !torch.list<int>\n"
@@ -615,6 +615,17 @@ static Type getPromotedResultTypeAssumingNonZeroRank(
      /*skipRankCheck=*/true);
}

static Type getPromotedResultTypeAssumingNonZeroRankWithQuantizedPromotion(
    MLIRContext *context, ArrayRef<const ValueKnowledge *> tensors) {
  auto promotedType =
      getPromotedResultTypeAssumingNonZeroRank(context, tensors);
  if (promotedType.isSignedInteger(8))
    return mlir::IntegerType::get(context, 32, IntegerType::Signed);
  if (promotedType.isUnsignedInteger(8))
    return mlir::IntegerType::get(context, 32, IntegerType::Unsigned);
  return promotedType;
}

void TypeAnalysis::fillInDTypeGivenDTypeIntAndInputDType(
    ValueKnowledge &knowledge, Value dtype, Type inputDType) {
  assert(!inputDType ||
@@ -733,8 +744,10 @@ void TypeAnalysis::visitOperation(Operation *op,
                 AtenMseLossOp>(op)) {
    auto knowledge =
        ValueKnowledge::getTensorPessimisticValueState(op->getContext());
    knowledge.dtype = getPromotedResultTypeAssumingNonZeroRank(
        op->getContext(), {&operands[0]->getValue(), &operands[1]->getValue()});
    knowledge.dtype =
        getPromotedResultTypeAssumingNonZeroRankWithQuantizedPromotion(
            op->getContext(),
            {&operands[0]->getValue(), &operands[1]->getValue()});
    incorporateKnowledge(op->getResult(0), knowledge);
    return;
  }
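The widening in getPromotedResultTypeAssumingNonZeroRankWithQuantizedPromotion above (and the i32 accumulator chosen for 8-bit convolutions earlier in this patch) presumably exists because 8-bit products overflow almost immediately once accumulated; a quick numeric illustration, not part of the patch:

import numpy as np

a = np.full((1, 64), 127, dtype=np.int8)
b = np.full((64, 1), 127, dtype=np.int8)

# Accumulating in int32 keeps the exact value; an int8 accumulator would wrap.
acc = a.astype(np.int32) @ b.astype(np.int32)
print(acc.item())  # 1032256, far outside the int8 range [-128, 127]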
@@ -1034,6 +1034,18 @@ def aten〇fft_fft〡dtype(self_rank: int, self_dtype: int, n: Optional[int] = N
    else:
        assert False, "Unsupported dtype"

def aten〇quantize_per_tensor〡shape(self: List[int], scale: float, zero_point: int, dtype: int) -> List[int]:
    return self

def aten〇quantize_per_tensor〡dtype(self_rank: int, self_dtype: int, scale: float, zero_point: int, dtype: int) -> int:
    return dtype

def aten〇quantize_per_channel〡shape(self: List[int], scales: List[int], zero_points: List[int], axis: int, dtype: int) -> List[int]:
    return self

def aten〇quantize_per_channel〡dtype(self_rank: int, self_dtype: int, scales_rank: int, scales_dtype: int, zero_points_rank: int, zero_points_dtype: int, axis: int, dtype: int) -> int:
    return dtype

class DummyClassType:
    def __init__(self):
        pass
@@ -650,6 +650,11 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
    emit("aten::native_dropout_backward : (Tensor, Tensor, float) -> (Tensor)")
    emit("aten::leaky_relu_backward : (Tensor, Tensor, Scalar, bool) -> (Tensor)")

    # quantization ops
    emit("aten::quantize_per_tensor : (Tensor, float, int, int) -> (Tensor)", has_canonicalizer=True)
    emit("aten::quantize_per_channel : (Tensor, Tensor, Tensor, int, int) -> (Tensor)", has_canonicalizer=True)
    emit("aten::int_repr : (Tensor) -> (Tensor)")

    # ==========================================================================
    # `prim::` namespace.
    # ==========================================================================