mirror of https://github.com/llvm/torch-mlir
Add Brevitas quantization for per tensor and per channel modes
parent c29c07b29b
commit 23ced3dca6
@@ -10635,6 +10635,84 @@ def Torch_AtenLeakyReluBackwardOp : Torch_Op<"aten.leaky_relu_backward", [
   }];
 }
 
+def Torch_AtenQuantizePerTensorOp : Torch_Op<"aten.quantize_per_tensor", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::quantize_per_tensor : (Tensor, float, int, int) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$self,
+    Torch_FloatType:$scale,
+    Torch_IntType:$zero_point,
+    Torch_IntType:$dtype
+  );
+  let results = (outs
+    AnyTorchTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenQuantizePerTensorOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 4, 1);
+    }
+    void AtenQuantizePerTensorOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 4, 1);
+    }
+  }];
+  let hasCanonicalizer = 1;
+}
+
+def Torch_AtenQuantizePerChannelOp : Torch_Op<"aten.quantize_per_channel", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::quantize_per_channel : (Tensor, Tensor, Tensor, int, int) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$self,
+    AnyTorchTensorType:$scales,
+    AnyTorchTensorType:$zero_points,
+    Torch_IntType:$axis,
+    Torch_IntType:$dtype
+  );
+  let results = (outs
+    AnyTorchTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenQuantizePerChannelOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 5, 1);
+    }
+    void AtenQuantizePerChannelOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 5, 1);
+    }
+  }];
+  let hasCanonicalizer = 1;
+}
+
+def Torch_AtenIntReprOp : Torch_Op<"aten.int_repr", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::int_repr : (Tensor) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$self
+  );
+  let results = (outs
+    AnyTorchTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenIntReprOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 1, 1);
+    }
+    void AtenIntReprOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 1, 1);
+    }
+  }];
+}
+
 def Torch_PrimLayoutOp : Torch_Op<"prim.layout", [
     AllowsTypeRefinement,
     HasValueSemantics,
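These three ODS definitions mirror the upstream PyTorch operators one-for-one. For orientation, a minimal sketch of the semantics at the PyTorch level, using the standard torch API (not part of this patch):

import torch

x = torch.randn(2, 3)

# Per-tensor: a single scale/zero_point pair for the whole tensor.
q = torch.quantize_per_tensor(x, scale=0.1, zero_point=0, dtype=torch.qint8)

# Per-channel: one scale/zero_point per slice along `axis`.
scales = torch.tensor([0.1, 0.2])
zero_points = torch.tensor([0, 0])
qc = torch.quantize_per_channel(x, scales, zero_points, axis=0, dtype=torch.qint8)

# int_repr drops the quantization parameters and returns the raw integers.
raw = q.int_repr()  # a plain torch.int8 tensor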
@@ -15,6 +15,7 @@
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Math/IR/Math.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/IR/Matchers.h"
 #include "torch-mlir/Conversion/Utils/Utils.h"
@@ -474,6 +475,109 @@ public:
 };
 } // namespace
 
+namespace {
+class ConvertAtenQuantizePerChannelOp
+    : public OpConversionPattern<AtenQuantizePerChannelOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(AtenQuantizePerChannelOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = op->getLoc();
+    MLIRContext *context = op.getContext();
+    Value self = adaptor.getSelf();
+    auto selfRank =
+        adaptor.getSelf().getType().cast<RankedTensorType>().getRank();
+    Type inputDtype =
+        adaptor.getSelf().getType().cast<RankedTensorType>().getElementType();
+    if (!inputDtype.isa<mlir::FloatType>()) {
+      return rewriter.notifyMatchFailure(
+          op, "input tensor must be floating point dtype for quantization");
+    }
+
+    Type newResultType = getTypeConverter()->convertType(op.getType());
+    Type outputDtype = newResultType.cast<RankedTensorType>().getElementType();
+    if (!outputDtype.isa<mlir::IntegerType>()) {
+      return rewriter.notifyMatchFailure(
+          op, "target quantization type must be an integer type");
+    }
+
+    int64_t axis;
+    if (!matchPattern(op.getAxis(), m_TorchConstantInt(&axis)))
+      return rewriter.notifyMatchFailure(op, "only constant axis supported");
+
+    // Zero-initialized output tensor; every element is overwritten by the
+    // linalg.generic payload below.
+    Value initTensor = createZeroInitTensor(
+        rewriter, loc, getTensorSizes(rewriter, loc, self), outputDtype);
+
+    SmallVector<utils::IteratorType> iteratorTypes(
+        selfRank, utils::IteratorType::parallel);
+
+    // Scales and zero points are rank-1 tensors indexed by the quantization
+    // axis; the input and output use the identity map.
+    AffineMap quantAxisMap =
+        AffineMap::get(selfRank, 0, rewriter.getAffineDimExpr(axis), context);
+    AffineMap identity = AffineMap::getMultiDimIdentityMap(selfRank, context);
+
+    SmallVector<AffineMap> indexingMaps{identity, quantAxisMap, quantAxisMap,
+                                        identity};
+
+    Value quantized =
+        rewriter
+            .create<linalg::GenericOp>(
+                loc, newResultType,
+                ValueRange{self, adaptor.getScales(), adaptor.getZeroPoints()},
+                initTensor, indexingMaps, iteratorTypes,
+                [&](OpBuilder &b, Location loc, ValueRange args) {
+                  // q = clamp(round(x / scale) + zero_point, qmin, qmax)
+                  Value lhs = convertScalarToDtype(b, loc, args[0], inputDtype);
+                  Value scale =
+                      convertScalarToDtype(b, loc, args[1], inputDtype);
+                  Value scaled = b.create<arith::DivFOp>(loc, lhs, scale);
+                  Value rounded = b.create<math::RoundOp>(loc, scaled);
+                  Type intermediateDtype = b.getIntegerType(
+                      cast<mlir::FloatType>(inputDtype).getWidth());
+                  Value fpToI =
+                      convertScalarToDtype(b, loc, rounded, intermediateDtype);
+
+                  Value zeroPoint =
+                      convertScalarToDtype(b, loc, args[2], intermediateDtype);
+                  Value shifted =
+                      b.create<arith::AddIOp>(loc, fpToI, zeroPoint);
+
+                  Value quantMin = b.create<arith::ConstantOp>(
+                      loc,
+                      b.getIntegerAttr(
+                          intermediateDtype,
+                          llvm::minIntN(cast<mlir::IntegerType>(outputDtype)
+                                            .getWidth())));
+                  Value quantMax = b.create<arith::ConstantOp>(
+                      loc,
+                      b.getIntegerAttr(
+                          intermediateDtype,
+                          llvm::maxIntN(cast<mlir::IntegerType>(outputDtype)
+                                            .getWidth())));
+                  Value minCompare = b.create<arith::CmpIOp>(
+                      loc, arith::CmpIPredicate::slt, shifted, quantMin);
+                  Value minClamp = b.create<arith::SelectOp>(loc, minCompare,
+                                                             quantMin, shifted);
+
+                  Value maxCompare = b.create<arith::CmpIOp>(
+                      loc, arith::CmpIPredicate::sgt, minClamp, quantMax);
+                  Value maxClamp = b.create<arith::SelectOp>(
+                      loc, maxCompare, quantMax, minClamp);
+
+                  Value truncated =
+                      convertScalarToDtype(b, loc, maxClamp, outputDtype);
+                  b.create<linalg::YieldOp>(loc, truncated);
+                })
+            .getResult(0);
+
+    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, quantized);
+    return success();
+  }
+};
+} // namespace
+
 namespace {
 class ConvertAtenConvolutionOp : public OpConversionPattern<AtenConvolutionOp> {
 public:
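Elementwise, the payload above computes q = clamp(round(x / scale[c]) + zero_point[c], qmin, qmax), where c is the index along the quantization axis. A numpy reference of the same math (quantize_per_channel_ref is a hypothetical helper, illustrative only; numpy's round-half-to-even differs from math::RoundOp's round-half-away-from-zero at exact halves):

import numpy as np

def quantize_per_channel_ref(x, scales, zero_points, axis, qmin=-128, qmax=127):
    # Reshape the rank-1 per-channel parameters to broadcast along `axis`.
    shape = [1] * x.ndim
    shape[axis] = -1
    q = np.round(x / scales.reshape(shape)) + zero_points.reshape(shape)
    return np.clip(q, qmin, qmax).astype(np.int8)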
@@ -493,8 +597,8 @@ public:
     Type elementType =
         input.getType().cast<RankedTensorType>().getElementType();
-    if (!elementType.isa<mlir::FloatType>())
-      return op.emitError("unimplemented: non-floating point type");
+    // if (!elementType.isa<mlir::FloatType>())
+    //   return op.emitError("unimplemented: non-floating point type");
     size_t inRank = input.getType().cast<RankedTensorType>().getRank();
     size_t numSpacialDims = inRank - 2;
     if (numSpacialDims != 2)
@@ -661,21 +765,25 @@ public:
       castIndexToInt(weightDims[i]), strideIntValues[i]));
     }
 
+    // Accumulate i8 convolutions into i32 so the reduction does not overflow.
+    Type outputElementType = elementType;
+    if (outputElementType.isInteger(8))
+      outputElementType = rewriter.getIntegerType(32);
+
     Value initTensor = rewriter.create<tensor::EmptyOp>(
-        loc, getAsOpFoldResult(outDims), elementType);
+        loc, getAsOpFoldResult(outDims), outputElementType);
 
     Value bias = adaptor.getBias();
     Value outputTensor;
     if (bias.getType().isa<Torch::NoneType>()) {
-      Value c0float = rewriter.create<arith::ConstantOp>(
-          loc, FloatAttr::get(elementType, 0.0));
-      outputTensor = rewriter.create<linalg::FillOp>(loc, c0float, initTensor)
-                         .getResult(0);
+      Value c0Val = rewriter.create<arith::ConstantOp>(
+          loc, rewriter.getZeroAttr(outputElementType));
+      outputTensor =
+          rewriter.create<linalg::FillOp>(loc, c0Val, initTensor).getResult(0);
     } else {
       auto biasType = bias.getType().cast<RankedTensorType>();
       if (biasType.getRank() != 1)
         return rewriter.notifyMatchFailure(op, "expect bias to be rank 1");
-      if (elementType != biasType.getElementType())
+      if (outputElementType != biasType.getElementType())
         return rewriter.notifyMatchFailure(op, "unimplemented: type promotion");
 
       auto resultRank = initTensor.getType().cast<RankedTensorType>().getRank();
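The widening matters because a product of two int8 values does not fit in int8; accumulating in i32 keeps the integer convolution exact. A quick numpy illustration (not part of the patch):

import numpy as np

a = np.array([100], dtype=np.int8)
b = np.array([100], dtype=np.int8)
print(a * b)                    # [16]: 10000 wraps modulo 2**8 in int8
print(a.astype(np.int32) * b)   # [10000]: exact in an i32 accumulator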
@@ -840,6 +948,8 @@ void mlir::torch::torch_to_linalg::populateLinearPatternsAndLegality(
   patterns.add<ConvertAtenMmOp>(typeConverter, context);
   target.addIllegalOp<AtenFlipOp>();
   patterns.add<ConvertAtenFlipOp>(typeConverter, context);
+  target.addIllegalOp<AtenQuantizePerChannelOp>();
+  patterns.add<ConvertAtenQuantizePerChannelOp>(typeConverter, context);
   target.addIllegalOp<AtenMatmulOp>();
   patterns.add<ConvertAtenMatmulOp>(typeConverter, context);
   target.addIllegalOp<AtenBmmOp>();
@@ -45,8 +45,7 @@ static Value createComparisonTemplate(OpBuilder &b, Location loc, Type type,
   if (IntegerType intType = type.dyn_cast<mlir::IntegerType>()) {
     if (intType.isUnsigned())
       return b.create<arith::CmpIOp>(loc, iupred, lhs, rhs);
-    if (intType.isSigned())
-      return b.create<arith::CmpIOp>(loc, ispred, lhs, rhs);
+    return b.create<arith::CmpIOp>(loc, ispred, lhs, rhs);
   }
   llvm_unreachable("Unhandled element type for comparison");
 }
@@ -798,10 +797,6 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
     Type dtype = converter->convertType(clamp.getType())
                      .cast<RankedTensorType>()
                      .getElementType();
-    if (!dtype.isa<mlir::FloatType>()) {
-      clamp.emitError("unimplemented: non-floating point dtype");
-      return nullptr;
-    }
     AtenClampOp::Adaptor adaptor(operands);
     auto min = adaptor.getMin();
     auto max = adaptor.getMax();
@@ -813,14 +808,18 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
     auto result = payloadArgs[0];
     if (!min.getType().isa<Torch::NoneType>()) {
       auto minPromoted = convertScalarToDtype(b, loc, min, dtype);
-      auto pred = b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::ULT,
-                                          result, minPromoted);
+      auto pred = createComparisonTemplate<arith::CmpFPredicate::ULT,
+                                           arith::CmpIPredicate::ult,
+                                           arith::CmpIPredicate::slt>(
+          b, loc, dtype, result, minPromoted);
       result = b.create<arith::SelectOp>(loc, pred, minPromoted, result);
     }
     if (!max.getType().isa<Torch::NoneType>()) {
       auto maxPromoted = convertScalarToDtype(b, loc, max, dtype);
-      auto pred = b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UGT,
-                                          result, maxPromoted);
+      auto pred = createComparisonTemplate<arith::CmpFPredicate::UGT,
+                                           arith::CmpIPredicate::ugt,
+                                           arith::CmpIPredicate::sgt>(
+          b, loc, dtype, result, maxPromoted);
       result = b.create<arith::SelectOp>(loc, pred, maxPromoted, result);
     }
     return result;
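With the float-only check removed and the comparisons routed through createComparisonTemplate, aten.clamp now also lowers for the integer tensors produced by quantization; the template picks the float, unsigned, or signed predicate from the element type, and with the hunk above, signless integers take the signed path instead of hitting llvm_unreachable. Signedness genuinely changes the result, as this numpy sketch shows (illustrative only):

import numpy as np

x = np.array([-1], dtype=np.int8)
print(np.clip(x, 0, 10))                 # signed compare: [0]
print(np.clip(x.view(np.uint8), 0, 10))  # same bits read as 255 unsigned: [10]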
@@ -995,6 +994,65 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
     return convertScalarToDtype(b, loc, payloadArgs[1], dtype);
   }
 
+  if (auto quantizePerTensor = dyn_cast<AtenQuantizePerTensorOp>(op)) {
+    Type inputDtype =
+        converter->convertType(quantizePerTensor.getSelf().getType())
+            .cast<RankedTensorType>()
+            .getElementType();
+    if (!inputDtype.isa<mlir::FloatType>()) {
+      quantizePerTensor.emitError(
+          "input tensor must be floating point dtype for quantization");
+      return nullptr;
+    }
+
+    Type outputDtype =
+        converter->convertType(quantizePerTensor.getResult().getType())
+            .cast<RankedTensorType>()
+            .getElementType();
+    if (!outputDtype.isa<mlir::IntegerType>()) {
+      quantizePerTensor.emitError(
+          "target quantization type must be an integer type");
+      return nullptr;
+    }
+    // q = clamp(round(x / scale) + zero_point, qmin, qmax)
+    Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], inputDtype);
+    Value scale = convertScalarToDtype(b, loc, operands[1], inputDtype);
+    Value scaled = b.create<arith::DivFOp>(loc, lhs, scale);
+    Value rounded = b.create<math::RoundOp>(loc, scaled);
+    Type intermediateDtype =
+        b.getIntegerType(cast<mlir::FloatType>(inputDtype).getWidth());
+    Value fpToI = convertScalarToDtype(b, loc, rounded, intermediateDtype);
+
+    Value zeroPoint =
+        convertScalarToDtype(b, loc, operands[2], intermediateDtype);
+    Value shifted = b.create<arith::AddIOp>(loc, fpToI, zeroPoint);
+
+    Value quantMin = b.create<arith::ConstantOp>(
+        loc,
+        b.getIntegerAttr(
+            intermediateDtype,
+            llvm::minIntN(cast<mlir::IntegerType>(outputDtype).getWidth())));
+    Value quantMax = b.create<arith::ConstantOp>(
+        loc,
+        b.getIntegerAttr(
+            intermediateDtype,
+            llvm::maxIntN(cast<mlir::IntegerType>(outputDtype).getWidth())));
+    Value minCompare = createComparisonTemplate<arith::CmpFPredicate::ULT,
+                                                arith::CmpIPredicate::ult,
+                                                arith::CmpIPredicate::slt>(
+        b, loc, intermediateDtype, shifted, quantMin);
+    Value minClamp =
+        b.create<arith::SelectOp>(loc, minCompare, quantMin, shifted);
+
+    Value maxCompare = createComparisonTemplate<arith::CmpFPredicate::UGT,
+                                                arith::CmpIPredicate::ugt,
+                                                arith::CmpIPredicate::sgt>(
+        b, loc, intermediateDtype, minClamp, quantMax);
+    Value maxClamp =
+        b.create<arith::SelectOp>(loc, maxCompare, quantMax, minClamp);
+
+    return convertScalarToDtype(b, loc, maxClamp, outputDtype);
+  }
+
   if (auto triu = dyn_cast<AtenTriuOp>(op)) {
     // Check if the rank of the input tensor is valid.
     AtenTriuOp::Adaptor adaptor(operands);
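This is the per-tensor analogue of the per-channel payload earlier, with a scalar scale and zero point. A compact numpy reference (quantize_per_tensor_ref is a hypothetical helper, illustrative only):

import numpy as np

def quantize_per_tensor_ref(x, scale, zero_point, qmin=-128, qmax=127):
    q = np.round(x / scale) + zero_point
    return np.clip(q, qmin, qmax).astype(np.int8)

print(quantize_per_tensor_ref(np.array([0.49, -1.0, 25.0]), 0.1, 0))
# [5 -10 127]: 25.0 / 0.1 = 250 is clamped to the int8 maximum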
@@ -1088,7 +1146,8 @@ public:
           AtenNeScalarOp, AtenNegOp, AtenMaskedFillScalarOp,
           AtenMaskedFillTensorOp, AtenLogicalOrOp, AtenLogicalAndOp,
           AtenLogicalXorOp, AtenLogicalNotOp, AtenTriuOp, AtenBitwiseNotOp,
-          AtenRoundOp, AtenFillScalarOp, AtenFillTensorOp>(op))
+          AtenRoundOp, AtenFillScalarOp, AtenFillTensorOp,
+          AtenQuantizePerTensorOp>(op))
     return rewriter.notifyMatchFailure(op, "not a supported elementwise op");
 
   if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
@@ -1578,4 +1637,7 @@ void mlir::torch::torch_to_linalg::populateUncategorizedPatternsAndLegality(
   patterns.add<ConvertAtenNllLossBackwardOp>(typeConverter, context);
   patterns.add<ConvertTensorStaticInfoCastOp>(typeConverter, context);
   target.addIllegalOp<TensorStaticInfoCastOp>();
+
+  // Quantization
+  target.addIllegalOp<AtenQuantizePerTensorOp>();
 }
@@ -94,8 +94,7 @@ Value torch_to_linalg::getDynamicZeroPaddedTensor(
                                  SmallVector<int64_t>(inRank, kUnknownSize))),
         elementType);
 
-  Value cf0 =
-      b.create<arith::ConstantOp>(loc, b.getFloatAttr(elementType, 0.0));
+  Value cf0 = b.create<arith::ConstantOp>(loc, b.getZeroAttr(elementType));
   SmallVector<OpFoldResult> paddingValues =
       getAsOpFoldResult(paddingIncludingUnchanged);
   return b.create<tensor::PadOp>(loc, inputType, input, /*low=*/paddingValues,
@@ -263,7 +263,7 @@ Value convertScalarToDtype(OpBuilder &b, Location loc, Value scalar, Type dtype,
 
   // We only support conversion from Byte or Char scalarType not to Byte or Char
   // dtype.
-  if (isByteOrChar(dtype)) {
+  if (isByteOrChar(dtype) && !scalarType.isa<mlir::IntegerType>()) {
     mlir::emitError(loc) << "unsupported: conversion to byte or char type for "
                             "convertScalarToDtype "
                          << scalarType << "(scalar type) -> " << dtype
@@ -29,6 +29,17 @@ using namespace mlir::torch::Torch;
 // Utilities
 //===----------------------------------------------------------------------===//
 
+static inline torch_upstream::ScalarType
+materializeQType(torch_upstream::ScalarType t) {
+  if (t == torch_upstream::ScalarType::QInt8)
+    return torch_upstream::ScalarType::Char;
+  if (t == torch_upstream::ScalarType::QUInt8)
+    return torch_upstream::ScalarType::Char;
+  if (t == torch_upstream::ScalarType::QInt32)
+    return torch_upstream::ScalarType::Char;
+  return t;
+}
+
 Value mlir::torch::Torch::adjustStaticInformation(OpBuilder &builder,
                                                   Location loc, Value value,
                                                   Type desiredType,
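Note that QUInt8 and QInt32 also collapse to Char (int8) here; Byte and Int would be the faithful unsigned and 32-bit targets, so as committed the mapping is effectively int8-only, which lines up with the int8 Brevitas flow this patch targets. The Char storage is what int_repr exposes at the PyTorch level (standard API, illustrative):

import torch

q = torch.quantize_per_tensor(torch.randn(4), scale=0.1, zero_point=0, dtype=torch.qint8)
print(q.int_repr().dtype)  # torch.int8, i.e. the Char type this mapping produces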
@@ -1393,6 +1404,167 @@ void AtenSortIntOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
   });
 }
 
+//===----------------------------------------------------------------------===//
+// AtenQuantizePerTensorOp
+//===----------------------------------------------------------------------===//
+
+static void transposeOutputChannels(PatternRewriter &rewriter, Location loc,
+                                    Operation *target, Value other) {
+  Value zero =
+      rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(0));
+  Value one =
+      rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(1));
+  AtenTransposeIntOp transposed = rewriter.create<AtenTransposeIntOp>(
+      loc,
+      Torch::NonValueTensorType::getWithLeastStaticInformation(
+          target->getContext()),
+      other, zero, one);
+  transposed->moveBefore(target);
+  target->replaceUsesOfWith(other, transposed.getResult());
+}
+
+// Matches a dequantize chain (quantize -> [clamp] -> sub zero_point -> mul
+// scale) feeding a convolution and moves the sub/mul after the convolution,
+// so the convolution itself consumes the raw quantized integers.
+template <typename QuantizationOp>
+static LogicalResult commuteQuantizedConvolution(QuantizationOp op,
+                                                 PatternRewriter &rewriter) {
+  auto result = op.getResult();
+  if (!result.hasOneUse())
+    return rewriter.notifyMatchFailure(op, "quantize op has multiple uses");
+
+  if (auto clampOp = dyn_cast<AtenClampOp>(*result.getUsers().begin())) {
+    result = clampOp.getResult();
+  }
+
+  AtenSubTensorOp subZeroPoint;
+  if (!(subZeroPoint = dyn_cast<AtenSubTensorOp>(*result.getUsers().begin()))) {
+    return rewriter.notifyMatchFailure(
+        op, "quantize op does not have sub tensor as user");
+  }
+  auto subResult = subZeroPoint.getResult();
+
+  AtenMulTensorOp mulScale;
+  if (!(mulScale = dyn_cast<AtenMulTensorOp>(*subResult.getUsers().begin()))) {
+    return rewriter.notifyMatchFailure(
+        op, "quantize op does not have mul tensor in chain");
+  }
+  auto mulResult = mulScale.getResult();
+
+  Aten_ConvolutionOp conv;
+  if (!(conv = dyn_cast<Aten_ConvolutionOp>(*mulResult.getUsers().begin()))) {
+    return rewriter.notifyMatchFailure(
+        op, "quantize op does not have convolution in chain");
+  }
+
+  auto convResult = conv.getResult();
+  conv->replaceUsesOfWith(mulResult, result);
+  convResult.replaceAllUsesWith(mulResult);
+  subZeroPoint->replaceUsesOfWith(result, convResult);
+  subZeroPoint->moveAfter(conv);
+  mulScale->moveAfter(subZeroPoint);
+
+  if (isa<AtenQuantizePerChannelOp>(op)) {
+    // For per-channel quantization, swap the two leading dims of the scale
+    // and zero-point operands so they line up with the moved sub/mul.
+    Value other;
+    if (subZeroPoint.getSelf() == convResult) {
+      other = subZeroPoint.getOther();
+    } else {
+      other = subZeroPoint.getSelf();
+    }
+    Location otherLoc = conv->getLoc();
+    transposeOutputChannels(rewriter, otherLoc, subZeroPoint, other);
+
+    if (mulScale.getSelf() == subResult) {
+      other = mulScale.getOther();
+    } else {
+      other = mulScale.getSelf();
+    }
+    transposeOutputChannels(rewriter, otherLoc, mulScale, other);
+  }
+  return success();
+}
+
+void AtenQuantizePerTensorOp::getCanonicalizationPatterns(
+    RewritePatternSet &patterns, MLIRContext *context) {
+  // Fold quantize_per_tensor + int_repr into a single quantize_per_tensor
+  // that targets the underlying storage dtype.
+  patterns.add(+[](AtenQuantizePerTensorOp op, PatternRewriter &rewriter) {
+    auto loc = op.getLoc();
+    auto result = op.getResult();
+    if (!result.hasOneUse())
+      return rewriter.notifyMatchFailure(
+          op, "quantize per tensor op has multiple uses");
+
+    AtenIntReprOp reprOp;
+    if (!(reprOp = dyn_cast<AtenIntReprOp>(*result.getUsers().begin())))
+      return rewriter.notifyMatchFailure(
+          op, "quantize per tensor op use must be aten.int_repr");
+
+    int64_t dtypeInt;
+    if (!matchPattern(op.getDtype(), m_TorchConstantInt(&dtypeInt))) {
+      return failure();
+    }
+
+    auto scalarDtype = materializeQType((torch_upstream::ScalarType)dtypeInt);
+    auto dtypeValue = rewriter.create<Torch::ConstantIntOp>(
+        loc, rewriter.getI64IntegerAttr((int64_t)scalarDtype));
+
+    auto resultType = result.getType();
+    auto quantized = rewriter
+                         .create<AtenQuantizePerTensorOp>(
+                             loc, resultType, op.getSelf(), op.getScale(),
+                             op.getZeroPoint(), dtypeValue)
+                         .getResult();
+
+    reprOp.getResult().replaceAllUsesWith(quantized);
+    rewriter.eraseOp(reprOp);
+    rewriter.replaceOp(op, quantized);
+    return success();
+  });
+  patterns.add(+[](AtenQuantizePerTensorOp op, PatternRewriter &rewriter) {
+    return commuteQuantizedConvolution<AtenQuantizePerTensorOp>(op, rewriter);
+  });
+}
+
+//===----------------------------------------------------------------------===//
+// AtenQuantizePerChannelOp
+//===----------------------------------------------------------------------===//
+
+void AtenQuantizePerChannelOp::getCanonicalizationPatterns(
+    RewritePatternSet &patterns, MLIRContext *context) {
+  // Same int_repr folding as the per-tensor case above.
+  patterns.add(+[](AtenQuantizePerChannelOp op, PatternRewriter &rewriter) {
+    auto loc = op.getLoc();
+    auto result = op.getResult();
+    if (!result.hasOneUse())
+      return rewriter.notifyMatchFailure(
+          op, "quantize per channel op has multiple uses");
+
+    AtenIntReprOp reprOp;
+    if (!(reprOp = dyn_cast<AtenIntReprOp>(*result.getUsers().begin())))
+      return rewriter.notifyMatchFailure(
+          op, "quantize per channel op use must be aten.int_repr");
+
+    int64_t dtypeInt;
+    if (!matchPattern(op.getDtype(), m_TorchConstantInt(&dtypeInt))) {
+      return failure();
+    }
+
+    auto scalarDtype = materializeQType((torch_upstream::ScalarType)dtypeInt);
+    auto dtypeValue = rewriter.create<Torch::ConstantIntOp>(
+        loc, rewriter.getI64IntegerAttr((int64_t)scalarDtype));
+
+    auto resultType = result.getType();
+    auto quantized = rewriter
+                         .create<AtenQuantizePerChannelOp>(
+                             loc, resultType, op.getSelf(), op.getScales(),
+                             op.getZeroPoints(), op.getAxis(), dtypeValue)
+                         .getResult();
+
+    reprOp.getResult().replaceAllUsesWith(quantized);
+    rewriter.eraseOp(reprOp);
+    rewriter.replaceOp(op, quantized);
+    return success();
+  });
+  patterns.add(+[](AtenQuantizePerChannelOp op, PatternRewriter &rewriter) {
+    return commuteQuantizedConvolution<AtenQuantizePerChannelOp>(op, rewriter);
+  });
+}
+
 //===----------------------------------------------------------------------===//
 // NonValueTensorLiteralOp
 //===----------------------------------------------------------------------===//
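The commutation pays off because the convolution then consumes raw integers and can use the i8-to-i32 accumulation path added earlier, with a single dequantize (sub/mul) at the end. It relies on the convolution being linear in its input; a minimal numpy check of that identity for the per-tensor, zero-zero-point case (illustrative only):

import numpy as np

q = np.array([1.0, -2.0, 3.0, 4.0])  # integer-valued input, kept as floats
w = np.array([0.5, 0.25])
scale = 0.1

before = np.convolve(q * scale, w, mode="valid")  # dequantize, then convolve
after = np.convolve(q, w, mode="valid") * scale   # convolve, then dequantize
assert np.allclose(before, after)                 # equal, by linearity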
@@ -7446,6 +7446,18 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    }\n"
 "    return %3 : !torch.int\n"
 "  }\n"
+"  func.func @\"__torch_mlir_shape_fn.aten.quantize_per_tensor\"(%arg0: !torch.list<int>, %arg1: !torch.float, %arg2: !torch.int, %arg3: !torch.int) -> !torch.list<int> {\n"
+"    return %arg0 : !torch.list<int>\n"
+"  }\n"
+"  func.func @\"__torch_mlir_dtype_fn.aten.quantize_per_tensor\"(%arg0: !torch.int, %arg1: !torch.int, %arg2: !torch.float, %arg3: !torch.int, %arg4: !torch.int) -> !torch.int {\n"
+"    return %arg4 : !torch.int\n"
+"  }\n"
+"  func.func @\"__torch_mlir_shape_fn.aten.quantize_per_channel\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>, %arg3: !torch.int, %arg4: !torch.int) -> !torch.list<int> {\n"
+"    return %arg0 : !torch.list<int>\n"
+"  }\n"
+"  func.func @\"__torch_mlir_dtype_fn.aten.quantize_per_channel\"(%arg0: !torch.int, %arg1: !torch.int, %arg2: !torch.int, %arg3: !torch.int, %arg4: !torch.int, %arg5: !torch.int, %arg6: !torch.int, %arg7: !torch.int) -> !torch.int {\n"
+"    return %arg7 : !torch.int\n"
+"  }\n"
 "  func.func @\"__torch_mlir_shape_fn.aten.bincount\"(%arg0: !torch.list<int>, %arg1: !torch.optional<list<int>>, %arg2: !torch.int) -> !torch.list<int> {\n"
 "    %0 = call @__torch__.hacky_get_unknown_dimension_size() : () -> !torch.int\n"
 "    %1 = torch.prim.ListConstruct %0 : (!torch.int) -> !torch.list<int>\n"
@@ -615,6 +615,17 @@ static Type getPromotedResultTypeAssumingNonZeroRank(
                                                 /*skipRankCheck=*/true);
 }
 
+static Type getPromotedResultTypeAssumingNonZeroRankWithQuantizedPromotion(
+    MLIRContext *context, ArrayRef<const ValueKnowledge *> tensors) {
+  auto promotedType =
+      getPromotedResultTypeAssumingNonZeroRank(context, tensors);
+  if (promotedType.isSignedInteger(8))
+    return mlir::IntegerType::get(context, 32, IntegerType::Signed);
+  if (promotedType.isUnsignedInteger(8))
+    return mlir::IntegerType::get(context, 32, IntegerType::Unsigned);
+  return promotedType;
+}
+
 void TypeAnalysis::fillInDTypeGivenDTypeIntAndInputDType(
     ValueKnowledge &knowledge, Value dtype, Type inputDType) {
   assert(!inputDType ||
@@ -733,8 +744,10 @@ void TypeAnalysis::visitOperation(Operation *op,
           AtenMseLossOp>(op)) {
     auto knowledge =
         ValueKnowledge::getTensorPessimisticValueState(op->getContext());
-    knowledge.dtype = getPromotedResultTypeAssumingNonZeroRank(
-        op->getContext(), {&operands[0]->getValue(), &operands[1]->getValue()});
+    knowledge.dtype =
+        getPromotedResultTypeAssumingNonZeroRankWithQuantizedPromotion(
+            op->getContext(),
+            {&operands[0]->getValue(), &operands[1]->getValue()});
     incorporateKnowledge(op->getResult(0), knowledge);
     return;
   }
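This keeps the dtype analysis consistent with the widened convolution accumulator: any binary op whose promoted element type would be 8-bit is reported as the corresponding 32-bit type instead. The intended behavior as a toy Python mapping (hypothetical helper, illustrative only):

def promote_quantized(dtype: str) -> str:
    # Mirrors getPromotedResultTypeAssumingNonZeroRankWithQuantizedPromotion.
    return {"si8": "si32", "ui8": "ui32"}.get(dtype, dtype)

assert promote_quantized("si8") == "si32"
assert promote_quantized("f32") == "f32"  # everything else is unchanged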
@@ -1034,6 +1034,18 @@ def aten〇fft_fft〡dtype(self_rank: int, self_dtype: int, n: Optional[int] = N
     else:
         assert False, "Unsupported dtype"
 
+def aten〇quantize_per_tensor〡shape(self: List[int], scale: float, zero_point: int, dtype: int) -> List[int]:
+    return self
+
+def aten〇quantize_per_tensor〡dtype(self_rank: int, self_dtype: int, scale: float, zero_point: int, dtype: int) -> int:
+    return dtype
+
+def aten〇quantize_per_channel〡shape(self: List[int], scales: List[int], zero_points: List[int], axis: int, dtype: int) -> List[int]:
+    return self
+
+def aten〇quantize_per_channel〡dtype(self_rank: int, self_dtype: int, scales_rank: int, scales_dtype: int, zero_points_rank: int, zero_points_dtype: int, axis: int, dtype: int) -> int:
+    return dtype
+
 class DummyClassType:
     def __init__(self):
         pass
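These encode the two invariants the abstract interpreter needs: quantization preserves shape, and the result dtype is whatever dtype argument was passed; regenerating the shape library from them is what produced the AbstractInterpLibrary.cpp strings in the earlier hunk. A quick sanity check, assuming the definitions above are importable (12 is torch's QInt8 scalar-type code):

assert aten〇quantize_per_tensor〡shape([2, 3], scale=0.1, zero_point=0, dtype=12) == [2, 3]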
@@ -650,6 +650,11 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
     emit("aten::native_dropout_backward : (Tensor, Tensor, float) -> (Tensor)")
     emit("aten::leaky_relu_backward : (Tensor, Tensor, Scalar, bool) -> (Tensor)")
 
+    # quantization ops
+    emit("aten::quantize_per_tensor : (Tensor, float, int, int) -> (Tensor)", has_canonicalizer=True)
+    emit("aten::quantize_per_channel : (Tensor, Tensor, Tensor, int, int) -> (Tensor)", has_canonicalizer=True)
+    emit("aten::int_repr : (Tensor) -> (Tensor)")
+
     # ==========================================================================
     # `prim::` namespace.
     # ==========================================================================