From 23ced3dca6a200175400de0d1813a06cc4040bb6 Mon Sep 17 00:00:00 2001
From: Quinn Dawkins
Date: Mon, 16 Jan 2023 01:17:29 -0500
Subject: [PATCH] Add Brevitas quantization for per tensor and per channel
 modes

---
 .../Dialect/Torch/IR/GeneratedTorchOps.td     |  78 ++++++++
 lib/Conversion/TorchToLinalg/Linear.cpp       | 126 ++++++++++++-
 .../TorchToLinalg/Uncategorized.cpp           |  84 +++++++--
 lib/Conversion/TorchToLinalg/Utils.cpp        |   3 +-
 lib/Conversion/Utils/Utils.cpp                |   2 +-
 lib/Dialect/Torch/IR/TorchOps.cpp             | 172 ++++++++++++++++++
 .../Transforms/AbstractInterpLibrary.cpp      |  12 ++
 lib/Dialect/Torch/Transforms/RefineTypes.cpp  |  17 +-
 .../build_tools/abstract_interp_lib_gen.py    |  12 ++
 .../jit_ir/build_tools/torch_ods_gen.py       |   5 +
 10 files changed, 487 insertions(+), 24 deletions(-)

diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
index 9aa76ff35..815217eba 100644
--- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
+++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
@@ -10635,6 +10635,84 @@ def Torch_AtenLeakyReluBackwardOp : Torch_Op<"aten.leaky_relu_backward", [
   }];
 }
 
+def Torch_AtenQuantizePerTensorOp : Torch_Op<"aten.quantize_per_tensor", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::quantize_per_tensor : (Tensor, float, int, int) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$self,
+    Torch_FloatType:$scale,
+    Torch_IntType:$zero_point,
+    Torch_IntType:$dtype
+  );
+  let results = (outs
+    AnyTorchTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenQuantizePerTensorOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 4, 1);
+    }
+    void AtenQuantizePerTensorOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 4, 1);
+    }
+  }];
+  let hasCanonicalizer = 1;
+}
+
+def Torch_AtenQuantizePerChannelOp : Torch_Op<"aten.quantize_per_channel", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::quantize_per_channel : (Tensor, Tensor, Tensor, int, int) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$self,
+    AnyTorchTensorType:$scales,
+    AnyTorchTensorType:$zero_points,
+    Torch_IntType:$axis,
+    Torch_IntType:$dtype
+  );
+  let results = (outs
+    AnyTorchTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenQuantizePerChannelOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 5, 1);
+    }
+    void AtenQuantizePerChannelOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 5, 1);
+    }
+  }];
+  let hasCanonicalizer = 1;
+}
+
+def Torch_AtenIntReprOp : Torch_Op<"aten.int_repr", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::int_repr : (Tensor) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$self
+  );
+  let results = (outs
+    AnyTorchTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenIntReprOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 1, 1);
+    }
+    void AtenIntReprOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 1, 1);
+    }
+  }];
+}
+
 def Torch_PrimLayoutOp : Torch_Op<"prim.layout", [
     AllowsTypeRefinement,
     HasValueSemantics,
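For readers unfamiliar with these three ops: they mirror PyTorch's eager quantization API. A minimal sketch of their semantics using standard `torch` calls (tensor values invented for illustration):

    import torch

    x = torch.randn(3, 4)

    # aten::quantize_per_tensor : (Tensor, float, int, int) -> (Tensor)
    q = torch.quantize_per_tensor(x, scale=0.1, zero_point=3, dtype=torch.qint8)

    # aten::quantize_per_channel : (Tensor, Tensor, Tensor, int, int) -> (Tensor)
    scales = torch.tensor([0.1, 0.2, 0.3])
    zero_points = torch.tensor([0, 1, 2])
    qc = torch.quantize_per_channel(x, scales, zero_points, axis=0,
                                    dtype=torch.qint8)

    # aten::int_repr : (Tensor) -> (Tensor) exposes the raw integer storage.
    assert q.int_repr().dtype == torch.int8
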
diff --git a/lib/Conversion/TorchToLinalg/Linear.cpp b/lib/Conversion/TorchToLinalg/Linear.cpp
index 92ed647ef..5513c5db0 100644
--- a/lib/Conversion/TorchToLinalg/Linear.cpp
+++ b/lib/Conversion/TorchToLinalg/Linear.cpp
@@ -15,6 +15,7 @@
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Math/IR/Math.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/IR/Matchers.h"
 #include "torch-mlir/Conversion/Utils/Utils.h"
@@ -474,6 +475,109 @@ public:
 };
 } // namespace
 
+namespace {
+class ConvertAtenQuantizePerChannelOp
+    : public OpConversionPattern<AtenQuantizePerChannelOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(AtenQuantizePerChannelOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+
+    Location loc = op->getLoc();
+    MLIRContext *context = op.getContext();
+    Value self = adaptor.getSelf();
+    auto selfRank =
+        adaptor.getSelf().getType().cast<RankedTensorType>().getRank();
+    Type inputDtype =
+        adaptor.getSelf().getType().cast<RankedTensorType>().getElementType();
+    if (!inputDtype.isa<mlir::FloatType>()) {
+      return rewriter.notifyMatchFailure(
+          op, "input tensor must be floating point dtype for quantization");
+    }
+
+    Type newResultType = getTypeConverter()->convertType(op.getType());
+    Type outputDtype = newResultType.cast<RankedTensorType>().getElementType();
+    if (!outputDtype.isa<mlir::IntegerType>()) {
+      return rewriter.notifyMatchFailure(
+          op, "target quantization type must be an integer type");
+    }
+
+    int64_t axis;
+    if (!matchPattern(op.getAxis(), m_TorchConstantInt(&axis)))
+      return rewriter.notifyMatchFailure(op, "only constant axis supported");
+
+    // The init tensor only provides the output shape; every element is
+    // overwritten by the linalg.generic payload below.
+    Value initTensor = createZeroInitTensor(
+        rewriter, loc, getTensorSizes(rewriter, loc, self), outputDtype);
+
+    SmallVector<utils::IteratorType> iteratorTypes(
+        selfRank, utils::IteratorType::parallel);
+
+    AffineMap quantAxisMap =
+        AffineMap::get(selfRank, 0, rewriter.getAffineDimExpr(axis), context);
+    AffineMap identity = AffineMap::getMultiDimIdentityMap(selfRank, context);
+
+    SmallVector<AffineMap> indexingMaps{identity, quantAxisMap, quantAxisMap,
+                                        identity};
+
+    Value quantized =
+        rewriter
+            .create<linalg::GenericOp>(
+                loc, newResultType,
+                ValueRange{self, adaptor.getScales(), adaptor.getZeroPoints()},
+                initTensor, indexingMaps, iteratorTypes,
+                [&](OpBuilder &b, Location loc, ValueRange args) {
+                  Value lhs = convertScalarToDtype(b, loc, args[0], inputDtype);
+                  Value scale =
+                      convertScalarToDtype(b, loc, args[1], inputDtype);
+                  Value scaled = b.create<arith::DivFOp>(loc, lhs, scale);
+                  Value rounded = b.create<math::RoundOp>(loc, scaled);
+                  Type intermediateDtype = b.getIntegerType(
+                      cast<mlir::FloatType>(inputDtype).getWidth());
+                  Value fpToI =
+                      convertScalarToDtype(b, loc, rounded, intermediateDtype);
+
+                  // aten semantics: q = round(x / scale) + zero_point.
+                  Value zeroPoint =
+                      convertScalarToDtype(b, loc, args[2], intermediateDtype);
+                  Value shifted =
+                      b.create<arith::AddIOp>(loc, fpToI, zeroPoint);
+
+                  Value quantMin = b.create<arith::ConstantOp>(
+                      loc,
+                      b.getIntegerAttr(
+                          intermediateDtype,
+                          llvm::minIntN(cast<mlir::IntegerType>(outputDtype)
+                                            .getWidth())));
+                  Value quantMax = b.create<arith::ConstantOp>(
+                      loc,
+                      b.getIntegerAttr(
+                          intermediateDtype,
+                          llvm::maxIntN(cast<mlir::IntegerType>(outputDtype)
+                                            .getWidth())));
+                  Value minCompare = b.create<arith::CmpIOp>(
+                      loc, arith::CmpIPredicate::slt, shifted, quantMin);
+                  Value minClamp = b.create<arith::SelectOp>(
+                      loc, minCompare, quantMin, shifted);
+
+                  Value maxCompare = b.create<arith::CmpIOp>(
+                      loc, arith::CmpIPredicate::sgt, minClamp, quantMax);
+                  Value maxClamp = b.create<arith::SelectOp>(
+                      loc, maxCompare, quantMax, minClamp);
+
+                  Value truncated =
+                      convertScalarToDtype(b, loc, maxClamp, outputDtype);
+                  b.create<linalg::YieldOp>(loc, truncated);
+                })
+            .getResult(0);
+
+    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, quantized);
+    return success();
+  }
+};
+} // namespace
+
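For reference, a plain-Python sketch of what the linalg.generic payload above computes per element (hypothetical helper, 2-D case only, int8 bounds assumed):

    def quantize_per_channel_ref(x, scales, zero_points, axis,
                                 qmin=-128, qmax=127):
        # q = clamp(round(v / scale[c]) + zero_point[c], qmin, qmax), where c
        # indexes the quantization axis (the role of quantAxisMap above).
        out = [[0] * len(x[0]) for _ in x]
        for i in range(len(x)):
            for j in range(len(x[0])):
                c = i if axis == 0 else j
                q = round(x[i][j] / scales[c]) + zero_points[c]
                out[i][j] = max(qmin, min(qmax, q))
        return out
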
 namespace {
 class ConvertAtenConvolutionOp : public OpConversionPattern<AtenConvolutionOp> {
 public:
@@ -493,8 +597,8 @@ public:
 
     Type elementType =
         input.getType().cast<RankedTensorType>().getElementType();
-    if (!elementType.isa<mlir::FloatType>())
-      return op.emitError("unimplemented: non-floating point type");
+    if (!elementType.isa<mlir::FloatType, mlir::IntegerType>())
+      return op.emitError("unimplemented: non-floating-point, non-integer type");
     size_t inRank = input.getType().cast<RankedTensorType>().getRank();
     size_t numSpacialDims = inRank - 2;
     if (numSpacialDims != 2)
@@ -661,21 +765,25 @@ public:
           castIndexToInt(weightDims[i]), strideIntValues[i]));
     }
 
+    // 8-bit convolutions accumulate into a wider 32-bit element type to
+    // avoid overflow.
+    Type outputElementType = elementType;
+    if (outputElementType.isInteger(8))
+      outputElementType = rewriter.getIntegerType(32);
+
     Value initTensor = rewriter.create<tensor::EmptyOp>(
-        loc, getAsOpFoldResult(outDims), elementType);
+        loc, getAsOpFoldResult(outDims), outputElementType);
 
     Value bias = adaptor.getBias();
     Value outputTensor;
     if (bias.getType().isa<Torch::NoneType>()) {
-      Value c0float = rewriter.create<arith::ConstantOp>(
-          loc, FloatAttr::get(elementType, 0.0));
-      outputTensor = rewriter.create<linalg::FillOp>(loc, c0float, initTensor)
-                         .getResult(0);
+      Value c0Val = rewriter.create<arith::ConstantOp>(
+          loc, rewriter.getZeroAttr(outputElementType));
+      outputTensor =
+          rewriter.create<linalg::FillOp>(loc, c0Val, initTensor).getResult(0);
     } else {
       auto biasType = bias.getType().cast<RankedTensorType>();
       if (biasType.getRank() != 1)
         return rewriter.notifyMatchFailure(op, "expect bias to be rank 1");
-      if (elementType != biasType.getElementType())
+      if (outputElementType != biasType.getElementType())
         return rewriter.notifyMatchFailure(op, "unimplemented: type promotion");
 
       auto resultRank = initTensor.getType().cast<RankedTensorType>().getRank();
@@ -840,6 +948,8 @@ void mlir::torch::torch_to_linalg::populateLinearPatternsAndLegality(
   patterns.add<ConvertAtenMmOp>(typeConverter, context);
   target.addIllegalOp<AtenFlipOp>();
   patterns.add<ConvertAtenFlipOp>(typeConverter, context);
+  target.addIllegalOp<AtenQuantizePerChannelOp>();
+  patterns.add<ConvertAtenQuantizePerChannelOp>(typeConverter, context);
   target.addIllegalOp<AtenMatmulOp>();
   patterns.add<ConvertAtenMatmulOp>(typeConverter, context);
   target.addIllegalOp<AtenBmmOp>();
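The promotion of 8-bit convolutions to a 32-bit output element type above sidesteps accumulator overflow; a quick illustration in plain Python:

    # A single int8 product already exceeds the int8 range:
    # 127 * 127 = 16129, outside [-128, 127].
    acc = 0
    for a, b in [(127, 127)] * 9:   # one maximal 3x3 window
        acc += a * b                # 145161: needs a wider accumulator
    assert -2**31 <= acc < 2**31    # fits easily in int32
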
diff --git a/lib/Conversion/TorchToLinalg/Uncategorized.cpp b/lib/Conversion/TorchToLinalg/Uncategorized.cpp
index bc16c8c1e..6f847086f 100644
--- a/lib/Conversion/TorchToLinalg/Uncategorized.cpp
+++ b/lib/Conversion/TorchToLinalg/Uncategorized.cpp
@@ -45,8 +45,7 @@ static Value createComparisonTemplate(OpBuilder &b, Location loc, Type type,
   if (IntegerType intType = type.dyn_cast<IntegerType>()) {
     if (intType.isUnsigned())
       return b.create<arith::CmpIOp>(loc, iupred, lhs, rhs);
-    if (intType.isSigned())
-      return b.create<arith::CmpIOp>(loc, ispred, lhs, rhs);
+    return b.create<arith::CmpIOp>(loc, ispred, lhs, rhs);
   }
   llvm_unreachable("Unhandled element type for comparison");
 }
@@ -798,10 +797,6 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
     Type dtype = converter->convertType(clamp.getType())
                      .cast<RankedTensorType>()
                      .getElementType();
-    if (!dtype.isa<mlir::FloatType>()) {
-      clamp.emitError("unimplemented: non-floating point dtype");
-      return nullptr;
-    }
     AtenClampOp::Adaptor adaptor(operands);
     auto min = adaptor.getMin();
     auto max = adaptor.getMax();
@@ -813,14 +808,18 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
     auto result = payloadArgs[0];
     if (!min.getType().isa<Torch::NoneType>()) {
       auto minPromoted = convertScalarToDtype(b, loc, min, dtype);
-      auto pred = b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::ULT,
-                                          result, minPromoted);
+      auto pred = createComparisonTemplate<arith::CmpFPredicate::ULT,
+                                           arith::CmpIPredicate::ult,
+                                           arith::CmpIPredicate::slt>(
+          b, loc, dtype, result, minPromoted);
       result = b.create<arith::SelectOp>(loc, pred, minPromoted, result);
     }
     if (!max.getType().isa<Torch::NoneType>()) {
       auto maxPromoted = convertScalarToDtype(b, loc, max, dtype);
-      auto pred = b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UGT,
-                                          result, maxPromoted);
+      auto pred = createComparisonTemplate<arith::CmpFPredicate::UGT,
+                                           arith::CmpIPredicate::ugt,
+                                           arith::CmpIPredicate::sgt>(
+          b, loc, dtype, result, maxPromoted);
       result = b.create<arith::SelectOp>(loc, pred, maxPromoted, result);
     }
     return result;
@@ -995,6 +994,65 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
     return convertScalarToDtype(b, loc, payloadArgs[1], dtype);
   }
 
+  if (auto quantizePerTensor = dyn_cast<AtenQuantizePerTensorOp>(op)) {
+    Type inputDtype =
+        converter->convertType(quantizePerTensor.getSelf().getType())
+            .cast<RankedTensorType>()
+            .getElementType();
+    if (!inputDtype.isa<mlir::FloatType>()) {
+      quantizePerTensor.emitError(
+          "input tensor must be floating point dtype for quantization");
+      return nullptr;
+    }
+
+    Type outputDtype =
+        converter->convertType(quantizePerTensor.getResult().getType())
+            .cast<RankedTensorType>()
+            .getElementType();
+    if (!outputDtype.isa<mlir::IntegerType>()) {
+      quantizePerTensor.emitError(
+          "target quantization type must be an integer type");
+      return nullptr;
+    }
+    Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], inputDtype);
+    Value scale = convertScalarToDtype(b, loc, operands[1], inputDtype);
+    Value scaled = b.create<arith::DivFOp>(loc, lhs, scale);
+    Value rounded = b.create<math::RoundOp>(loc, scaled);
+    Type intermediateDtype =
+        b.getIntegerType(cast<mlir::FloatType>(inputDtype).getWidth());
+    Value fpToI = convertScalarToDtype(b, loc, rounded, intermediateDtype);
+
+    // aten semantics: q = round(x / scale) + zero_point.
+    Value zeroPoint =
+        convertScalarToDtype(b, loc, operands[2], intermediateDtype);
+    Value shifted = b.create<arith::AddIOp>(loc, fpToI, zeroPoint);
+
+    Value quantMin = b.create<arith::ConstantOp>(
+        loc,
+        b.getIntegerAttr(
+            intermediateDtype,
+            llvm::minIntN(cast<mlir::IntegerType>(outputDtype).getWidth())));
+    Value quantMax = b.create<arith::ConstantOp>(
+        loc,
+        b.getIntegerAttr(
+            intermediateDtype,
+            llvm::maxIntN(cast<mlir::IntegerType>(outputDtype).getWidth())));
+    Value minCompare = createComparisonTemplate<arith::CmpFPredicate::ULT,
+                                                arith::CmpIPredicate::ult,
+                                                arith::CmpIPredicate::slt>(
+        b, loc, intermediateDtype, shifted, quantMin);
+    Value minClamp =
+        b.create<arith::SelectOp>(loc, minCompare, quantMin, shifted);
+
+    Value maxCompare = createComparisonTemplate<arith::CmpFPredicate::UGT,
+                                                arith::CmpIPredicate::ugt,
+                                                arith::CmpIPredicate::sgt>(
+        b, loc, intermediateDtype, minClamp, quantMax);
+    Value maxClamp =
+        b.create<arith::SelectOp>(loc, maxCompare, quantMax, minClamp);
+
+    return convertScalarToDtype(b, loc, maxClamp, outputDtype);
+  }
+
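A worked example of the per-tensor payload above (values invented; int8 clamp bounds assumed):

    scale, zero_point = 0.25, 3
    x = 0.8
    q = round(x / scale) + zero_point     # round(3.2) + 3 = 6
    q = max(-128, min(127, q))            # clamp to the int8 range
    assert q == 6

    big = round(100.0 / scale) + zero_point   # 403, outside int8
    assert max(-128, min(127, big)) == 127
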
   if (auto triu = dyn_cast<AtenTriuOp>(op)) {
     // Check if the rank of the input tensor is valid.
     AtenTriuOp::Adaptor adaptor(operands);
@@ -1088,7 +1146,8 @@ public:
              AtenNeScalarOp, AtenNegOp, AtenMaskedFillScalarOp,
              AtenMaskedFillTensorOp, AtenLogicalOrOp, AtenLogicalAndOp,
              AtenLogicalXorOp, AtenLogicalNotOp, AtenTriuOp, AtenBitwiseNotOp,
-             AtenRoundOp, AtenFillScalarOp, AtenFillTensorOp>(op))
+             AtenRoundOp, AtenFillScalarOp, AtenFillTensorOp,
+             AtenQuantizePerTensorOp>(op))
       return rewriter.notifyMatchFailure(op, "not a supported elementwise op");
 
     if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
@@ -1578,4 +1637,7 @@ void mlir::torch::torch_to_linalg::populateUncategorizedPatternsAndLegality(
   patterns.add(typeConverter, context);
   patterns.add(typeConverter, context);
   target.addIllegalOp();
+
+  // Quantization
+  target.addIllegalOp<AtenQuantizePerTensorOp>();
 }
diff --git a/lib/Conversion/TorchToLinalg/Utils.cpp b/lib/Conversion/TorchToLinalg/Utils.cpp
index 58f807e74..f7cc638fb 100644
--- a/lib/Conversion/TorchToLinalg/Utils.cpp
+++ b/lib/Conversion/TorchToLinalg/Utils.cpp
@@ -94,8 +94,7 @@ Value torch_to_linalg::getDynamicZeroPaddedTensor(
                       SmallVector<int64_t>(inRank, kUnknownSize))),
       elementType);
 
-  Value cf0 =
-      b.create<arith::ConstantOp>(loc, b.getFloatAttr(elementType, 0.0));
+  Value cf0 = b.create<arith::ConstantOp>(loc, b.getZeroAttr(elementType));
   SmallVector<OpFoldResult> paddingValues =
       getAsOpFoldResult(paddingIncludingUnchanged);
   return b.create<tensor::PadOp>(loc, inputType, input, /*low=*/paddingValues,
diff --git a/lib/Conversion/Utils/Utils.cpp b/lib/Conversion/Utils/Utils.cpp
index 906cc3c44..2f66e9dd2 100644
--- a/lib/Conversion/Utils/Utils.cpp
+++ b/lib/Conversion/Utils/Utils.cpp
@@ -263,7 +263,7 @@ Value convertScalarToDtype(OpBuilder &b, Location loc, Value scalar, Type dtype,
 
   // We only support conversion from Byte or Char scalarType not to Byte or Char
   // dtype.
-  if (isByteOrChar(dtype)) {
+  if (isByteOrChar(dtype) && !scalarType.isa<mlir::IntegerType>()) {
     mlir::emitError(loc) << "unsupported: conversion to byte or char type for "
                             "convertScalarToDtype "
                          << scalarType << "(scalar type) -> " << dtype
diff --git a/lib/Dialect/Torch/IR/TorchOps.cpp b/lib/Dialect/Torch/IR/TorchOps.cpp
index 73cb0c9a8..2f45e39e1 100644
--- a/lib/Dialect/Torch/IR/TorchOps.cpp
+++ b/lib/Dialect/Torch/IR/TorchOps.cpp
@@ -29,6 +29,17 @@ using namespace mlir::torch::Torch;
 // Utilities
 //===----------------------------------------------------------------------===//
 
+// Map a quantized scalar type to the scalar type of its underlying integer
+// storage (c10 convention: QInt8 -> Char, QUInt8 -> Byte, QInt32 -> Int).
+static inline torch_upstream::ScalarType
+materializeQType(torch_upstream::ScalarType t) {
+  if (t == torch_upstream::ScalarType::QInt8)
+    return torch_upstream::ScalarType::Char;
+  if (t == torch_upstream::ScalarType::QUInt8)
+    return torch_upstream::ScalarType::Byte;
+  if (t == torch_upstream::ScalarType::QInt32)
+    return torch_upstream::ScalarType::Int;
+  return t;
+}
+
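The Char/Byte/Int choices in materializeQType mirror how eager PyTorch pairs each quantized dtype with an integer storage dtype, which can be checked directly:

    import torch

    for qdtype, storage in [(torch.qint8, torch.int8),
                            (torch.quint8, torch.uint8),
                            (torch.qint32, torch.int32)]:
        q = torch.quantize_per_tensor(torch.randn(2), 0.1, 0, qdtype)
        assert q.int_repr().dtype == storage
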
 Value mlir::torch::Torch::adjustStaticInformation(OpBuilder &builder,
                                                   Location loc, Value value,
                                                   Type desiredType,
@@ -1393,6 +1404,167 @@ void AtenSortIntOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
   });
 }
 
+//===----------------------------------------------------------------------===//
+// AtenQuantizePerTensorOp
+//===----------------------------------------------------------------------===//
+
+static void transposeOutputChannels(PatternRewriter &rewriter, Location loc,
+                                    Operation *target, Value other) {
+  Value zero =
+      rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(0));
+  Value one =
+      rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(1));
+  AtenTransposeIntOp transposed = rewriter.create<AtenTransposeIntOp>(
+      loc,
+      Torch::NonValueTensorType::getWithLeastStaticInformation(
+          target->getContext()),
+      other, zero, one);
+  transposed->moveBefore(target);
+  target->replaceUsesOfWith(other, transposed.getResult());
+}
+
+template <typename QuantizationOp>
+static LogicalResult commuteQuantizedConvolution(QuantizationOp op,
+                                                 PatternRewriter &rewriter) {
+  auto result = op.getResult();
+  if (!result.hasOneUse())
+    return rewriter.notifyMatchFailure(op, "quantize op has multiple uses");
+
+  // An optional clamp may sit between the quantize op and the dequantize
+  // chain.
+  if (auto clampOp = dyn_cast<AtenClampOp>(*result.getUsers().begin())) {
+    result = clampOp.getResult();
+  }
+
+  AtenSubTensorOp subZeroPoint;
+  if (!(subZeroPoint =
+            dyn_cast<AtenSubTensorOp>(*result.getUsers().begin()))) {
+    return rewriter.notifyMatchFailure(
+        op, "quantize op does not have sub tensor as user");
+  }
+  auto subResult = subZeroPoint.getResult();
+
+  AtenMulTensorOp mulScale;
+  if (!(mulScale = dyn_cast<AtenMulTensorOp>(*subResult.getUsers().begin()))) {
+    return rewriter.notifyMatchFailure(
+        op, "quantize op does not have mul tensor in chain");
+  }
+  auto mulResult = mulScale.getResult();
+
+  Aten_ConvolutionOp conv;
+  if (!(conv = dyn_cast<Aten_ConvolutionOp>(*mulResult.getUsers().begin()))) {
+    return rewriter.notifyMatchFailure(
+        op, "quantize op does not have convolution in chain");
+  }
+
+  // Commute the convolution with the sub/mul (dequantize) chain so the
+  // convolution consumes the quantized integer tensor directly.
+  auto convResult = conv.getResult();
+  conv->replaceUsesOfWith(mulResult, result);
+  convResult.replaceAllUsesWith(mulResult);
+  subZeroPoint->replaceUsesOfWith(result, convResult);
+  subZeroPoint->moveAfter(conv);
+  mulScale->moveAfter(subZeroPoint);
+
+  if (isa<AtenQuantizePerChannelOp>(op)) {
+    Value other;
+    if (subZeroPoint.getSelf() == convResult) {
+      other = subZeroPoint.getOther();
+    } else {
+      other = subZeroPoint.getSelf();
+    }
+    Location otherLoc = conv->getLoc();
+    transposeOutputChannels(rewriter, otherLoc, subZeroPoint, other);
+
+    if (mulScale.getSelf() == subResult) {
+      other = mulScale.getOther();
+    } else {
+      other = mulScale.getSelf();
+    }
+    transposeOutputChannels(rewriter, otherLoc, mulScale, other);
+  }
+  return success();
+}
+
+void AtenQuantizePerTensorOp::getCanonicalizationPatterns(
+    RewritePatternSet &patterns, MLIRContext *context) {
+  patterns.add(+[](AtenQuantizePerTensorOp op, PatternRewriter &rewriter) {
+    auto loc = op.getLoc();
+    auto result = op.getResult();
+    if (!result.hasOneUse())
+      return rewriter.notifyMatchFailure(
+          op, "quantize per tensor op has multiple uses");
+
+    AtenIntReprOp reprOp;
+    if (!(reprOp = dyn_cast<AtenIntReprOp>(*result.getUsers().begin())))
+      return rewriter.notifyMatchFailure(
+          op, "quantize per tensor op use must be aten.int_repr");
+
+    int64_t dtypeInt;
+    if (!matchPattern(op.getDtype(), m_TorchConstantInt(&dtypeInt))) {
+      return failure();
+    }
+
+    auto scalarDtype = materializeQType((torch_upstream::ScalarType)dtypeInt);
+    auto dtypeValue = rewriter.create<ConstantIntOp>(
+        loc, rewriter.getI64IntegerAttr((int64_t)scalarDtype));
+
+    auto resultType = result.getType();
+    auto quantized = rewriter
+                         .create<AtenQuantizePerTensorOp>(
+                             loc, resultType, op.getSelf(), op.getScale(),
+                             op.getZeroPoint(), dtypeValue)
+                         .getResult();
+
+    reprOp.getResult().replaceAllUsesWith(quantized);
+    rewriter.eraseOp(reprOp);
+    rewriter.replaceOp(op, quantized);
+    return success();
+  });
+  patterns.add(+[](AtenQuantizePerTensorOp op, PatternRewriter &rewriter) {
+    return commuteQuantizedConvolution(op, rewriter);
+  });
+}
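The chain these patterns match is what a Brevitas-style export produces at the Torch level; in eager terms the matched pattern is roughly the following (illustrative, simplified to the per-tensor case with invented shapes):

    import torch
    import torch.nn.functional as F

    x = torch.randn(1, 2, 8, 8)
    w = torch.randn(4, 2, 3, 3)
    scale, zero_point = 0.05, 0

    # quantize_per_tensor -> int_repr -> sub(zero_point) -> mul(scale) -> conv
    q = torch.quantize_per_tensor(x, scale, zero_point, torch.qint8)
    deq = (q.int_repr() - zero_point) * scale   # aten.sub.Tensor, aten.mul.Tensor
    y = F.conv2d(deq, w)                        # aten._convolution

The rewrite moves the sub/mul after the convolution so the convolution itself consumes the integer tensor, transposing the per-channel scale and zero-point operands onto the output channels in the per-channel case.
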
+
+//===----------------------------------------------------------------------===//
+// AtenQuantizePerChannelOp
+//===----------------------------------------------------------------------===//
+
+void AtenQuantizePerChannelOp::getCanonicalizationPatterns(
+    RewritePatternSet &patterns, MLIRContext *context) {
+  patterns.add(+[](AtenQuantizePerChannelOp op, PatternRewriter &rewriter) {
+    auto loc = op.getLoc();
+    auto result = op.getResult();
+    if (!result.hasOneUse())
+      return rewriter.notifyMatchFailure(
+          op, "quantize per channel op has multiple uses");
+
+    AtenIntReprOp reprOp;
+    if (!(reprOp = dyn_cast<AtenIntReprOp>(*result.getUsers().begin())))
+      return rewriter.notifyMatchFailure(
+          op, "quantize per channel op use must be aten.int_repr");
+
+    int64_t dtypeInt;
+    if (!matchPattern(op.getDtype(), m_TorchConstantInt(&dtypeInt))) {
+      return failure();
+    }
+
+    auto scalarDtype = materializeQType((torch_upstream::ScalarType)dtypeInt);
+    auto dtypeValue = rewriter.create<ConstantIntOp>(
+        loc, rewriter.getI64IntegerAttr((int64_t)scalarDtype));
+
+    auto resultType = result.getType();
+    auto quantized = rewriter
+                         .create<AtenQuantizePerChannelOp>(
+                             loc, resultType, op.getSelf(), op.getScales(),
+                             op.getZeroPoints(), op.getAxis(), dtypeValue)
+                         .getResult();
+
+    reprOp.getResult().replaceAllUsesWith(quantized);
+    rewriter.eraseOp(reprOp);
+    rewriter.replaceOp(op, quantized);
+    return success();
+  });
+  patterns.add(+[](AtenQuantizePerChannelOp op, PatternRewriter &rewriter) {
+    return commuteQuantizedConvolution(op, rewriter);
+  });
+}
+
 //===----------------------------------------------------------------------===//
 // NonValueTensorLiteralOp
 //===----------------------------------------------------------------------===//
diff --git a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
index 2e5960d66..5749b7929 100644
--- a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
+++ b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
@@ -7446,6 +7446,18 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    }\n"
 "    return %3 : !torch.int\n"
 "  }\n"
+"  func.func @\"__torch_mlir_shape_fn.aten.quantize_per_tensor\"(%arg0: !torch.list<int>, %arg1: !torch.float, %arg2: !torch.int, %arg3: !torch.int) -> !torch.list<int> {\n"
+"    return %arg0 : !torch.list<int>\n"
+"  }\n"
+"  func.func @\"__torch_mlir_dtype_fn.aten.quantize_per_tensor\"(%arg0: !torch.int, %arg1: !torch.int, %arg2: !torch.float, %arg3: !torch.int, %arg4: !torch.int) -> !torch.int {\n"
+"    return %arg4 : !torch.int\n"
+"  }\n"
+"  func.func @\"__torch_mlir_shape_fn.aten.quantize_per_channel\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>, %arg3: !torch.int, %arg4: !torch.int) -> !torch.list<int> {\n"
+"    return %arg0 : !torch.list<int>\n"
+"  }\n"
+"  func.func @\"__torch_mlir_dtype_fn.aten.quantize_per_channel\"(%arg0: !torch.int, %arg1: !torch.int, %arg2: !torch.int, %arg3: !torch.int, %arg4: !torch.int, %arg5: !torch.int, %arg6: !torch.int, %arg7: !torch.int) -> !torch.int {\n"
+"    return %arg7 : !torch.int\n"
+"  }\n"
 "  func.func @\"__torch_mlir_shape_fn.aten.bincount\"(%arg0: !torch.list<int>, %arg1: !torch.optional<list<int>>, %arg2: !torch.int) -> !torch.list<int> {\n"
 "    %0 = call @__torch__.hacky_get_unknown_dimension_size() : () -> !torch.int\n"
 "    %1 = torch.prim.ListConstruct %0 : (!torch.int) -> !torch.list<int>\n"
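The shape and dtype functions registered here are deliberately trivial: quantization is elementwise, and the result dtype is whatever the trailing dtype argument requests. Both invariants can be confirmed against eager PyTorch:

    import torch

    x = torch.randn(5, 7)
    q = torch.quantize_per_tensor(x, 0.1, 0, torch.qint8)
    assert q.shape == x.shape       # shape function: pass-through
    assert q.dtype == torch.qint8   # dtype function: the requested dtype
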
diff --git a/lib/Dialect/Torch/Transforms/RefineTypes.cpp b/lib/Dialect/Torch/Transforms/RefineTypes.cpp
index 568c99f84..eceb85a8e 100644
--- a/lib/Dialect/Torch/Transforms/RefineTypes.cpp
+++ b/lib/Dialect/Torch/Transforms/RefineTypes.cpp
@@ -615,6 +615,17 @@ static Type getPromotedResultTypeAssumingNonZeroRank(
                                     /*skipRankCheck=*/true);
 }
 
+static Type getPromotedResultTypeAssumingNonZeroRankWithQuantizedPromotion(
+    MLIRContext *context, ArrayRef<const ValueKnowledge *> tensors) {
+  auto promotedType =
+      getPromotedResultTypeAssumingNonZeroRank(context, tensors);
+  if (promotedType.isSignedInteger(8))
+    return mlir::IntegerType::get(context, 32, IntegerType::Signed);
+  if (promotedType.isUnsignedInteger(8))
+    return mlir::IntegerType::get(context, 32, IntegerType::Unsigned);
+  return promotedType;
+}
+
 void TypeAnalysis::fillInDTypeGivenDTypeIntAndInputDType(
     ValueKnowledge &knowledge, Value dtype, Type inputDType) {
   assert(!inputDType ||
@@ -733,8 +744,10 @@ void TypeAnalysis::visitOperation(Operation *op,
                 AtenMseLossOp>(op)) {
     auto knowledge =
         ValueKnowledge::getTensorPessimisticValueState(op->getContext());
-    knowledge.dtype = getPromotedResultTypeAssumingNonZeroRank(
-        op->getContext(), {&operands[0]->getValue(), &operands[1]->getValue()});
+    knowledge.dtype =
+        getPromotedResultTypeAssumingNonZeroRankWithQuantizedPromotion(
+            op->getContext(),
+            {&operands[0]->getValue(), &operands[1]->getValue()});
     incorporateKnowledge(op->getResult(0), knowledge);
     return;
   }
diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/abstract_interp_lib_gen.py b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/abstract_interp_lib_gen.py
index ccc2b9d43..0b1bb27a2 100644
--- a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/abstract_interp_lib_gen.py
+++ b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/abstract_interp_lib_gen.py
@@ -1034,6 +1034,18 @@ def aten〇fft_fft〡dtype(self_rank: int, self_dtype: int, n: Optional[int] = N
     else:
         assert False, "Unsupported dtype"
 
+def aten〇quantize_per_tensor〡shape(self: List[int], scale: float, zero_point: int, dtype: int) -> List[int]:
+    return self
+
+def aten〇quantize_per_tensor〡dtype(self_rank: int, self_dtype: int, scale: float, zero_point: int, dtype: int) -> int:
+    return dtype
+
+def aten〇quantize_per_channel〡shape(self: List[int], scales: List[int], zero_points: List[int], axis: int, dtype: int) -> List[int]:
+    return self
+
+def aten〇quantize_per_channel〡dtype(self_rank: int, self_dtype: int, scales_rank: int, scales_dtype: int, zero_points_rank: int, zero_points_dtype: int, axis: int, dtype: int) -> int:
+    return dtype
+
 class DummyClassType:
     def __init__(self):
         pass
diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py
index addb4aba0..2c025dec0 100644
--- a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py
+++ b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py
@@ -650,6 +650,11 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
     emit("aten::native_dropout_backward : (Tensor, Tensor, float) -> (Tensor)")
     emit("aten::leaky_relu_backward : (Tensor, Tensor, Scalar, bool) -> (Tensor)")
 
+    # quantization ops
+    emit("aten::quantize_per_tensor : (Tensor, float, int, int) -> (Tensor)", has_canonicalizer=True)
+    emit("aten::quantize_per_channel : (Tensor, Tensor, Tensor, int, int) -> (Tensor)", has_canonicalizer=True)
+    emit("aten::int_repr : (Tensor) -> (Tensor)")
+
     # ==========================================================================
     # `prim::` namespace.
     # ==========================================================================
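Finally, a hedged end-to-end sketch of a module that exercises the new ops through torch-mlir (using the torch_mlir.compile entry point of this era; whether the full pipeline lowers this depends on the patterns above):

    import torch
    import torch_mlir

    class QuantizeInt8(torch.nn.Module):
        def forward(self, x):
            q = torch.quantize_per_tensor(x, 0.1, 0, torch.qint8)
            return q.int_repr()   # folded away by the canonicalizations above

    compiled = torch_mlir.compile(QuantizeInt8(), torch.randn(4, 4),
                                  output_type="linalg-on-tensors")
    print(compiled)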