Add Brevitas quantization for per tensor and per channel modes

quinn/quantization oneshot-20230124.78
Quinn Dawkins 2023-01-16 01:17:29 -05:00 committed by Quinn Dawkins
parent c29c07b29b
commit 23ced3dca6
10 changed files with 487 additions and 24 deletions

View File

@ -10635,6 +10635,84 @@ def Torch_AtenLeakyReluBackwardOp : Torch_Op<"aten.leaky_relu_backward", [
}];
}
def Torch_AtenQuantizePerTensorOp : Torch_Op<"aten.quantize_per_tensor", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::quantize_per_tensor : (Tensor, float, int, int) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
Torch_FloatType:$scale,
Torch_IntType:$zero_point,
Torch_IntType:$dtype
);
let results = (outs
AnyTorchTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenQuantizePerTensorOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 4, 1);
}
void AtenQuantizePerTensorOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 4, 1);
}
}];
let hasCanonicalizer = 1;
}
def Torch_AtenQuantizePerChannelOp : Torch_Op<"aten.quantize_per_channel", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::quantize_per_channel : (Tensor, Tensor, Tensor, int, int) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
AnyTorchTensorType:$scales,
AnyTorchTensorType:$zero_points,
Torch_IntType:$axis,
Torch_IntType:$dtype
);
let results = (outs
AnyTorchTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenQuantizePerChannelOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 5, 1);
}
void AtenQuantizePerChannelOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 5, 1);
}
}];
let hasCanonicalizer = 1;
}
def Torch_AtenIntReprOp : Torch_Op<"aten.int_repr", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::int_repr : (Tensor) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self
);
let results = (outs
AnyTorchTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenIntReprOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 1, 1);
}
void AtenIntReprOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 1, 1);
}
}];
}
def Torch_PrimLayoutOp : Torch_Op<"prim.layout", [
AllowsTypeRefinement,
HasValueSemantics,

View File

@ -15,6 +15,7 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Matchers.h"
#include "torch-mlir/Conversion/Utils/Utils.h"
@ -474,6 +475,109 @@ public:
};
} // namespace
namespace {
class ConvertAtenQuantizePerChannelOp
: public OpConversionPattern<AtenQuantizePerChannelOp> {
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(AtenQuantizePerChannelOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Location loc = op->getLoc();
MLIRContext *context = op.getContext();
Value self = adaptor.getSelf();
auto selfRank =
adaptor.getSelf().getType().cast<RankedTensorType>().getRank();
Type inputDtype =
adaptor.getSelf().getType().cast<RankedTensorType>().getElementType();
if (!inputDtype.isa<mlir::FloatType>()) {
return rewriter.notifyMatchFailure(
op, "input tensor must be floating point dtype for quantization");
}
Type newResultType = getTypeConverter()->convertType(op.getType());
Type outputDtype = newResultType.cast<RankedTensorType>().getElementType();
if (!outputDtype.isa<mlir::IntegerType>()) {
return rewriter.notifyMatchFailure(
op, "target quantization type must be an integer type");
}
int64_t axis;
if (!matchPattern(op.getAxis(), m_TorchConstantInt(&axis)))
return rewriter.notifyMatchFailure(op, "only constant axis supported");
// Only used to calculate flipped values, i.e. those on the flip axes. Other
// dims won't be used.
Value initTensor = createZeroInitTensor(
rewriter, loc, getTensorSizes(rewriter, loc, self), outputDtype);
SmallVector<utils::IteratorType> iteratorTypes(
selfRank, utils::IteratorType::parallel);
AffineMap quantAxisMap =
AffineMap::get(selfRank, 0, rewriter.getAffineDimExpr(axis), context);
AffineMap identity = AffineMap::getMultiDimIdentityMap(selfRank, context);
SmallVector<AffineMap> indexingMaps{identity, quantAxisMap, quantAxisMap,
identity};
Value quantized =
rewriter
.create<linalg::GenericOp>(
loc, newResultType,
ValueRange{self, adaptor.getScales(), adaptor.getZeroPoints()},
initTensor, indexingMaps, iteratorTypes,
[&](OpBuilder &b, Location loc, ValueRange args) {
Value lhs = convertScalarToDtype(b, loc, args[0], inputDtype);
Value scale =
convertScalarToDtype(b, loc, args[1], inputDtype);
Value scaled = b.create<arith::DivFOp>(loc, lhs, scale);
Value rounded = b.create<math::RoundOp>(loc, scaled);
Type intermediateDtype = b.getIntegerType(
cast<mlir::FloatType>(inputDtype).getWidth());
Value fpToI =
convertScalarToDtype(b, loc, rounded, intermediateDtype);
Value zeroPoint =
convertScalarToDtype(b, loc, args[2], intermediateDtype);
Value shifted =
b.create<arith::AddIOp>(loc, fpToI, zeroPoint);
Value quantMin = b.create<arith::ConstantOp>(
loc,
b.getIntegerAttr(
intermediateDtype,
llvm::minIntN(cast<mlir::IntegerType>(outputDtype)
.getWidth())));
Value quantMax = b.create<arith::ConstantOp>(
loc,
b.getIntegerAttr(
intermediateDtype,
llvm::maxIntN(cast<mlir::IntegerType>(outputDtype)
.getWidth())));
Value minCompare = b.create<arith::CmpIOp>(
loc, arith::CmpIPredicate::slt, shifted, quantMin);
Value minClamp = b.create<arith::SelectOp>(loc, minCompare,
quantMin, shifted);
Value maxCompare = b.create<arith::CmpIOp>(
loc, arith::CmpIPredicate::sgt, minClamp, quantMax);
Value maxClamp = b.create<arith::SelectOp>(
loc, maxCompare, quantMax, minClamp);
Value truncated =
convertScalarToDtype(b, loc, maxClamp, outputDtype);
b.create<linalg::YieldOp>(loc, truncated);
})
.getResult(0);
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, quantized);
return success();
}
};
} // namespace
namespace {
class ConvertAtenConvolutionOp : public OpConversionPattern<AtenConvolutionOp> {
public:
@ -493,8 +597,8 @@ public:
Type elementType =
input.getType().cast<RankedTensorType>().getElementType();
if (!elementType.isa<mlir::FloatType>())
return op.emitError("unimplemented: non-floating point type");
// if (!elementType.isa<mlir::FloatType>())
// return op.emitError("unimplemented: non-floating point type");
size_t inRank = input.getType().cast<RankedTensorType>().getRank();
size_t numSpacialDims = inRank - 2;
if (numSpacialDims != 2)
@ -661,21 +765,25 @@ public:
castIndexToInt(weightDims[i]), strideIntValues[i]));
}
Type outputElementType = elementType;
if (outputElementType.isInteger(8))
outputElementType = rewriter.getIntegerType(32);
Value initTensor = rewriter.create<tensor::EmptyOp>(
loc, getAsOpFoldResult(outDims), elementType);
loc, getAsOpFoldResult(outDims), outputElementType);
Value bias = adaptor.getBias();
Value outputTensor;
if (bias.getType().isa<Torch::NoneType>()) {
Value c0float = rewriter.create<arith::ConstantOp>(
loc, FloatAttr::get(elementType, 0.0));
outputTensor = rewriter.create<linalg::FillOp>(loc, c0float, initTensor)
.getResult(0);
Value c0Val = rewriter.create<arith::ConstantOp>(
loc, rewriter.getZeroAttr(outputElementType));
outputTensor =
rewriter.create<linalg::FillOp>(loc, c0Val, initTensor).getResult(0);
} else {
auto biasType = bias.getType().cast<RankedTensorType>();
if (biasType.getRank() != 1)
return rewriter.notifyMatchFailure(op, "expect bias to be rank 1");
if (elementType != biasType.getElementType())
if (outputElementType != biasType.getElementType())
return rewriter.notifyMatchFailure(op, "unimplemented: type promotion");
auto resultRank = initTensor.getType().cast<RankedTensorType>().getRank();
@ -840,6 +948,8 @@ void mlir::torch::torch_to_linalg::populateLinearPatternsAndLegality(
patterns.add<ConvertAtenMmOp>(typeConverter, context);
target.addIllegalOp<AtenFlipOp>();
patterns.add<ConvertAtenFlipOp>(typeConverter, context);
target.addIllegalOp<AtenQuantizePerChannelOp>();
patterns.add<ConvertAtenQuantizePerChannelOp>(typeConverter, context);
target.addIllegalOp<AtenMatmulOp>();
patterns.add<ConvertAtenMatmulOp>(typeConverter, context);
target.addIllegalOp<AtenBmmOp>();

View File

@ -45,8 +45,7 @@ static Value createComparisonTemplate(OpBuilder &b, Location loc, Type type,
if (IntegerType intType = type.dyn_cast<mlir::IntegerType>()) {
if (intType.isUnsigned())
return b.create<arith::CmpIOp>(loc, iupred, lhs, rhs);
if (intType.isSigned())
return b.create<arith::CmpIOp>(loc, ispred, lhs, rhs);
return b.create<arith::CmpIOp>(loc, ispred, lhs, rhs);
}
llvm_unreachable("Unhandled element type for comparison");
}
@ -798,10 +797,6 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
Type dtype = converter->convertType(clamp.getType())
.cast<RankedTensorType>()
.getElementType();
if (!dtype.isa<mlir::FloatType>()) {
clamp.emitError("unimplemented: non-floating point dtype");
return nullptr;
}
AtenClampOp::Adaptor adaptor(operands);
auto min = adaptor.getMin();
auto max = adaptor.getMax();
@ -813,14 +808,18 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
auto result = payloadArgs[0];
if (!min.getType().isa<Torch::NoneType>()) {
auto minPromoted = convertScalarToDtype(b, loc, min, dtype);
auto pred = b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::ULT,
result, minPromoted);
auto pred = createComparisonTemplate<arith::CmpFPredicate::ULT,
arith::CmpIPredicate::ult,
arith::CmpIPredicate::slt>(
b, loc, dtype, result, minPromoted);
result = b.create<arith::SelectOp>(loc, pred, minPromoted, result);
}
if (!max.getType().isa<Torch::NoneType>()) {
auto maxPromoted = convertScalarToDtype(b, loc, max, dtype);
auto pred = b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UGT,
result, maxPromoted);
auto pred = createComparisonTemplate<arith::CmpFPredicate::UGT,
arith::CmpIPredicate::ugt,
arith::CmpIPredicate::sgt>(
b, loc, dtype, result, maxPromoted);
result = b.create<arith::SelectOp>(loc, pred, maxPromoted, result);
}
return result;
@ -995,6 +994,65 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
return convertScalarToDtype(b, loc, payloadArgs[1], dtype);
}
if (auto quantizePerTensor = dyn_cast<AtenQuantizePerTensorOp>(op)) {
Type inputDtype =
converter->convertType(quantizePerTensor.getSelf().getType())
.cast<RankedTensorType>()
.getElementType();
if (!inputDtype.isa<mlir::FloatType>()) {
quantizePerTensor.emitError(
"input tensor must be floating point dtype for quantization");
return nullptr;
}
Type outputDtype =
converter->convertType(quantizePerTensor.getResult().getType())
.cast<RankedTensorType>()
.getElementType();
if (!outputDtype.isa<mlir::IntegerType>()) {
quantizePerTensor.emitError(
"target quantization type must be an integer type");
return nullptr;
}
Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], inputDtype);
Value scale = convertScalarToDtype(b, loc, operands[1], inputDtype);
Value scaled = b.create<arith::DivFOp>(loc, lhs, scale);
Value rounded = b.create<math::RoundOp>(loc, scaled);
Type intermediateDtype =
b.getIntegerType(cast<mlir::FloatType>(inputDtype).getWidth());
Value fpToI = convertScalarToDtype(b, loc, rounded, intermediateDtype);
Value zeroPoint =
convertScalarToDtype(b, loc, operands[2], intermediateDtype);
Value shifted = b.create<arith::AddIOp>(loc, fpToI, zeroPoint);
Value quantMin = b.create<arith::ConstantOp>(
loc,
b.getIntegerAttr(
intermediateDtype,
llvm::minIntN(cast<mlir::IntegerType>(outputDtype).getWidth())));
Value quantMax = b.create<arith::ConstantOp>(
loc,
b.getIntegerAttr(
intermediateDtype,
llvm::maxIntN(cast<mlir::IntegerType>(outputDtype).getWidth())));
Value minCompare = createComparisonTemplate<arith::CmpFPredicate::ULT,
arith::CmpIPredicate::ult,
arith::CmpIPredicate::slt>(
b, loc, intermediateDtype, shifted, quantMin);
Value minClamp =
b.create<arith::SelectOp>(loc, minCompare, quantMin, shifted);
Value maxCompare = createComparisonTemplate<arith::CmpFPredicate::UGT,
arith::CmpIPredicate::ugt,
arith::CmpIPredicate::sgt>(
b, loc, intermediateDtype, minClamp, quantMax);
Value maxClamp =
b.create<arith::SelectOp>(loc, maxCompare, quantMax, minClamp);
return convertScalarToDtype(b, loc, maxClamp, outputDtype);
}
if (auto triu = dyn_cast<AtenTriuOp>(op)) {
// Check if the rank of the input tensor is valid.
AtenTriuOp::Adaptor adaptor(operands);
@ -1088,7 +1146,8 @@ public:
AtenNeScalarOp, AtenNegOp, AtenMaskedFillScalarOp,
AtenMaskedFillTensorOp, AtenLogicalOrOp, AtenLogicalAndOp,
AtenLogicalXorOp, AtenLogicalNotOp, AtenTriuOp, AtenBitwiseNotOp,
AtenRoundOp, AtenFillScalarOp, AtenFillTensorOp>(op))
AtenRoundOp, AtenFillScalarOp, AtenFillTensorOp,
AtenQuantizePerTensorOp>(op))
return rewriter.notifyMatchFailure(op, "not a supported elementwise op");
if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
@ -1578,4 +1637,7 @@ void mlir::torch::torch_to_linalg::populateUncategorizedPatternsAndLegality(
patterns.add<ConvertAtenNllLossBackwardOp>(typeConverter, context);
patterns.add<ConvertTensorStaticInfoCastOp>(typeConverter, context);
target.addIllegalOp<TensorStaticInfoCastOp>();
// Quantization
target.addIllegalOp<AtenQuantizePerTensorOp>();
}

View File

@ -94,8 +94,7 @@ Value torch_to_linalg::getDynamicZeroPaddedTensor(
SmallVector<int64_t>(inRank, kUnknownSize))),
elementType);
Value cf0 =
b.create<arith::ConstantOp>(loc, b.getFloatAttr(elementType, 0.0));
Value cf0 = b.create<arith::ConstantOp>(loc, b.getZeroAttr(elementType));
SmallVector<OpFoldResult> paddingValues =
getAsOpFoldResult(paddingIncludingUnchanged);
return b.create<tensor::PadOp>(loc, inputType, input, /*low=*/paddingValues,

View File

@ -263,7 +263,7 @@ Value convertScalarToDtype(OpBuilder &b, Location loc, Value scalar, Type dtype,
// We only support conversion from Byte or Char scalarType not to Byte or Char
// dtype.
if (isByteOrChar(dtype)) {
if (isByteOrChar(dtype) && !scalarType.isa<mlir::IntegerType>()) {
mlir::emitError(loc) << "unsupported: conversion to byte or char type for "
"convertScalarToDtype "
<< scalarType << "(scalar type) -> " << dtype

View File

@ -29,6 +29,17 @@ using namespace mlir::torch::Torch;
// Utilities
//===----------------------------------------------------------------------===//
static inline torch_upstream::ScalarType
materializeQType(torch_upstream::ScalarType t) {
if (t == torch_upstream::ScalarType::QInt8)
return torch_upstream::ScalarType::Char;
if (t == torch_upstream::ScalarType::QUInt8)
return torch_upstream::ScalarType::Char;
if (t == torch_upstream::ScalarType::QInt32)
return torch_upstream::ScalarType::Char;
return t;
}
Value mlir::torch::Torch::adjustStaticInformation(OpBuilder &builder,
Location loc, Value value,
Type desiredType,
@ -1393,6 +1404,167 @@ void AtenSortIntOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
});
}
//===----------------------------------------------------------------------===//
// AtenQuantizePerTensorOp
//===----------------------------------------------------------------------===//
static void transposeOutputChannels(PatternRewriter &rewriter, Location loc,
Operation *target, Value other) {
Value zero =
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(0));
Value one =
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(1));
AtenTransposeIntOp transposed = rewriter.create<AtenTransposeIntOp>(
loc,
Torch::NonValueTensorType::getWithLeastStaticInformation(
target->getContext()),
other, zero, one);
transposed->moveBefore(target);
target->replaceUsesOfWith(other, transposed.getResult());
}
template <typename QuantizationOp>
static LogicalResult commuteQuantizedConvolution(QuantizationOp op,
PatternRewriter &rewriter) {
auto result = op.getResult();
if (!result.hasOneUse())
return rewriter.notifyMatchFailure(op, "quantize op has multiple uses");
if (auto clampOp = dyn_cast<AtenClampOp>(*result.getUsers().begin())) {
result = clampOp.getResult();
}
AtenSubTensorOp subZeroPoint;
if (!(subZeroPoint = dyn_cast<AtenSubTensorOp>(*result.getUsers().begin()))) {
return rewriter.notifyMatchFailure(
op, "quantize op does not have sub tensor as user");
}
auto subResult = subZeroPoint.getResult();
AtenMulTensorOp mulScale;
if (!(mulScale = dyn_cast<AtenMulTensorOp>(*subResult.getUsers().begin()))) {
return rewriter.notifyMatchFailure(
op, "quantize op does not have mul tensor in chain");
}
auto mulResult = mulScale.getResult();
Aten_ConvolutionOp conv;
if (!(conv = dyn_cast<Aten_ConvolutionOp>(*mulResult.getUsers().begin()))) {
return rewriter.notifyMatchFailure(
op, "quantize op does not have convolution in chain");
}
auto convResult = conv.getResult();
conv->replaceUsesOfWith(mulResult, result);
convResult.replaceAllUsesWith(mulResult);
subZeroPoint->replaceUsesOfWith(result, convResult);
subZeroPoint->moveAfter(conv);
mulScale->moveAfter(subZeroPoint);
if (isa<AtenQuantizePerChannelOp>(op)) {
Value other;
if (subZeroPoint.getSelf() == convResult) {
other = subZeroPoint.getOther();
} else {
other = subZeroPoint.getSelf();
}
Location otherLoc = conv->getLoc();
transposeOutputChannels(rewriter, otherLoc, subZeroPoint, other);
if (mulScale.getSelf() == subResult) {
other = mulScale.getOther();
} else {
other = mulScale.getSelf();
}
transposeOutputChannels(rewriter, otherLoc, mulScale, other);
}
return success();
}
void AtenQuantizePerTensorOp::getCanonicalizationPatterns(
RewritePatternSet &patterns, MLIRContext *context) {
patterns.add(+[](AtenQuantizePerTensorOp op, PatternRewriter &rewriter) {
auto loc = op.getLoc();
auto result = op.getResult();
if (!result.hasOneUse())
return rewriter.notifyMatchFailure(
op, "quantize per tensor op has multiple uses");
AtenIntReprOp reprOp;
if (!(reprOp = dyn_cast<AtenIntReprOp>(*result.getUsers().begin())))
return rewriter.notifyMatchFailure(
op, "quantize per tensor op use must be aten.int_repr");
int64_t dtypeInt;
if (!matchPattern(op.getDtype(), m_TorchConstantInt(&dtypeInt))) {
return failure();
}
auto scalarDtype = materializeQType((torch_upstream::ScalarType)dtypeInt);
auto dtypeValue = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr((int64_t)scalarDtype));
auto resultType = result.getType();
auto quantized = rewriter
.create<AtenQuantizePerTensorOp>(
loc, resultType, op.getSelf(), op.getScale(),
op.getZeroPoint(), dtypeValue)
.getResult();
reprOp.getResult().replaceAllUsesWith(quantized);
rewriter.eraseOp(reprOp);
rewriter.replaceOp(op, quantized);
return success();
});
patterns.add(+[](AtenQuantizePerTensorOp op, PatternRewriter &rewriter) {
return commuteQuantizedConvolution<AtenQuantizePerTensorOp>(op, rewriter);
});
}
//===----------------------------------------------------------------------===//
// AtenQuantizePerChannelOp
//===----------------------------------------------------------------------===//
void AtenQuantizePerChannelOp::getCanonicalizationPatterns(
RewritePatternSet &patterns, MLIRContext *context) {
patterns.add(+[](AtenQuantizePerChannelOp op, PatternRewriter &rewriter) {
auto loc = op.getLoc();
auto result = op.getResult();
if (!result.hasOneUse())
return rewriter.notifyMatchFailure(
op, "quantize per channel op has multiple uses");
AtenIntReprOp reprOp;
if (!(reprOp = dyn_cast<AtenIntReprOp>(*result.getUsers().begin())))
return rewriter.notifyMatchFailure(
op, "quantize per channel op use must be aten.int_repr");
int64_t dtypeInt;
if (!matchPattern(op.getDtype(), m_TorchConstantInt(&dtypeInt))) {
return failure();
}
auto scalarDtype = materializeQType((torch_upstream::ScalarType)dtypeInt);
auto dtypeValue = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr((int64_t)scalarDtype));
auto resultType = result.getType();
auto quantized = rewriter
.create<AtenQuantizePerChannelOp>(
loc, resultType, op.getSelf(), op.getScales(),
op.getZeroPoints(), op.getAxis(), dtypeValue)
.getResult();
reprOp.getResult().replaceAllUsesWith(quantized);
rewriter.eraseOp(reprOp);
rewriter.replaceOp(op, quantized);
return success();
});
patterns.add(+[](AtenQuantizePerChannelOp op, PatternRewriter &rewriter) {
return commuteQuantizedConvolution<AtenQuantizePerChannelOp>(op, rewriter);
});
}
//===----------------------------------------------------------------------===//
// NonValueTensorLiteralOp
//===----------------------------------------------------------------------===//

View File

@ -7446,6 +7446,18 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" }\n"
" return %3 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.quantize_per_tensor\"(%arg0: !torch.list<int>, %arg1: !torch.float, %arg2: !torch.int, %arg3: !torch.int) -> !torch.list<int> {\n"
" return %arg0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.quantize_per_tensor\"(%arg0: !torch.int, %arg1: !torch.int, %arg2: !torch.float, %arg3: !torch.int, %arg4: !torch.int) -> !torch.int {\n"
" return %arg4 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.quantize_per_channel\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>, %arg3: !torch.int, %arg4: !torch.int) -> !torch.list<int> {\n"
" return %arg0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.quantize_per_channel\"(%arg0: !torch.int, %arg1: !torch.int, %arg2: !torch.int, %arg3: !torch.int, %arg4: !torch.int, %arg5: !torch.int, %arg6: !torch.int, %arg7: !torch.int) -> !torch.int {\n"
" return %arg7 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.bincount\"(%arg0: !torch.list<int>, %arg1: !torch.optional<list<int>>, %arg2: !torch.int) -> !torch.list<int> {\n"
" %0 = call @__torch__.hacky_get_unknown_dimension_size() : () -> !torch.int\n"
" %1 = torch.prim.ListConstruct %0 : (!torch.int) -> !torch.list<int>\n"

View File

@ -615,6 +615,17 @@ static Type getPromotedResultTypeAssumingNonZeroRank(
/*skipRankCheck=*/true);
}
static Type getPromotedResultTypeAssumingNonZeroRankWithQuantizedPromotion(
MLIRContext *context, ArrayRef<const ValueKnowledge *> tensors) {
auto promotedType =
getPromotedResultTypeAssumingNonZeroRank(context, tensors);
if (promotedType.isSignedInteger(8))
return mlir::IntegerType::get(context, 32, IntegerType::Signed);
if (promotedType.isUnsignedInteger(8))
return mlir::IntegerType::get(context, 32, IntegerType::Unsigned);
return promotedType;
}
void TypeAnalysis::fillInDTypeGivenDTypeIntAndInputDType(
ValueKnowledge &knowledge, Value dtype, Type inputDType) {
assert(!inputDType ||
@ -733,8 +744,10 @@ void TypeAnalysis::visitOperation(Operation *op,
AtenMseLossOp>(op)) {
auto knowledge =
ValueKnowledge::getTensorPessimisticValueState(op->getContext());
knowledge.dtype = getPromotedResultTypeAssumingNonZeroRank(
op->getContext(), {&operands[0]->getValue(), &operands[1]->getValue()});
knowledge.dtype =
getPromotedResultTypeAssumingNonZeroRankWithQuantizedPromotion(
op->getContext(),
{&operands[0]->getValue(), &operands[1]->getValue()});
incorporateKnowledge(op->getResult(0), knowledge);
return;
}

View File

@ -1034,6 +1034,18 @@ def atenfft_fft〡dtype(self_rank: int, self_dtype: int, n: Optional[int] = N
else:
assert False, "Unsupported dtype"
def atenquantize_per_tensor〡shape(self: List[int], scale: float, zero_point: int, dtype: int) -> List[int]:
return self
def atenquantize_per_tensor〡dtype(self_rank: int, self_dtype: int, scale: float, zero_point: int, dtype: int) -> int:
return dtype
def atenquantize_per_channel〡shape(self: List[int], scales: List[int], zero_points: List[int], axis: int, dtype: int) -> List[int]:
return self
def atenquantize_per_channel〡dtype(self_rank: int, self_dtype: int, scales_rank: int, scales_dtype: int, zero_points_rank: int, zero_points_dtype: int, axis: int, dtype: int) -> int:
return dtype
class DummyClassType:
def __init__(self):
pass

View File

@ -650,6 +650,11 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
emit("aten::native_dropout_backward : (Tensor, Tensor, float) -> (Tensor)")
emit("aten::leaky_relu_backward : (Tensor, Tensor, Scalar, bool) -> (Tensor)")
# quantization ops
emit("aten::quantize_per_tensor : (Tensor, float, int, int) -> (Tensor)", has_canonicalizer=True)
emit("aten::quantize_per_channel : (Tensor, Tensor, Tensor, int, int) -> (Tensor)", has_canonicalizer=True)
emit("aten::int_repr : (Tensor) -> (Tensor)")
# ==========================================================================
# `prim::` namespace.
# ==========================================================================