mirror of https://github.com/llvm/torch-mlir

Prototype passes for lowering quantized group matmul (#2402)

* Support brevitas custom op (#2320)
* f16 change for brevitas
* Adapt the change of brevitas quant custom op name
* Add unit tests
* Make brevitas conversions isolated
* Address the comments

Co-authored-by: dan <danimal197@gmail.com>

parent c42d2beb6e
commit 1682b540bf
@@ -57,6 +57,14 @@ std::unique_ptr<OperationPass<ModuleOp>> createFuncBackendTypeConversionPass();
 std::unique_ptr<OperationPass<func::FuncOp>>
 createFinalizingBackendTypeConversionPass();
 
+// These passes do a one-off conversion of a specific kind of quantized group
+// matmul as a prototype. Generalized quantized operation handling will likely
+// obviate them, but they are being carried for now in order to unblock
+// progress on full integrations. See
+// https://github.com/llvm/torch-mlir/issues/2417 for the plan to support a
+// more generalized lowering for these graphs.
+std::unique_ptr<OperationPass<func::FuncOp>> createUnpackQuantTensorPass();
+std::unique_ptr<OperationPass<func::FuncOp>> createConvertCustomQuantOpPass();
+
 std::unique_ptr<OperationPass<ModuleOp>>
 createVerifyLinalgOnTensorsBackendContractPass();
 
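The two declarations added above are the prototype's only public entry points; both run on func.func. A minimal sketch of wiring them into a lowering pipeline (hypothetical: this commit does not add them to any default flow, and the function name and `pm` are our assumptions):

    #include "mlir/Dialect/Func/IR/FuncOps.h"
    #include "mlir/Pass/PassManager.h"
    #include "torch-mlir/Dialect/TorchConversion/Transforms/Passes.h"

    // Hypothetical wiring; not part of this commit.
    void addPrototypeGroupQuantPasses(mlir::PassManager &pm) {
      // First rewrite int4-in-int8 weight constants while still on Torch IR...
      pm.addNestedPass<mlir::func::FuncOp>(
          mlir::torch::TorchConversion::createUnpackQuantTensorPass());
      // ...then lower the custom quantized matmul op to linalg. The lowering
      // pattern bails out unless the rhs element width already matches the
      // unpacked bit width, so unpacking has to come first for the int4 case.
      pm.addNestedPass<mlir::func::FuncOp>(
          mlir::torch::TorchConversion::createConvertCustomQuantOpPass());
    }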
@@ -48,4 +48,16 @@ def VerifyStablehloBackendContract : Pass<"torch-verify-stablehlo-backend-contra
   let constructor = "mlir::torch::TorchConversion::createVerifyStablehloBackendContractPass()";
 }
 #endif // TORCH_MLIR_ENABLE_STABLEHLO
+
+// The following passes are for a one-off conversion of a specific kind of quantized group matmul.
+// They should not be included in default lowering flows until further along.
+def UnpackQuantTensor : Pass<"torch-unpack-quant-tensor", "func::FuncOp"> {
+  let summary = "Unpack quantized int4 tensor from int8 container";
+  let constructor = "mlir::torch::TorchConversion::createUnpackQuantTensorPass()";
+}
+
+def ConvertCustomQuantOp : Pass<"torch-convert-custom-quant-op", "func::FuncOp"> {
+  let summary = "Convert torch custom quant op to linalg";
+  let constructor = "mlir::torch::TorchConversion::createConvertCustomQuantOpPass()";
+}
 #endif // TORCHMLIR_TORCHCONVERSION_PASSES
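These defs register the command-line flags exercised by the new tests (`-torch-unpack-quant-tensor`, `-torch-convert-custom-quant-op`) and generate the `UnpackQuantTensorBase` / `ConvertCustomQuantOpBase` classes the implementations below derive from. A hedged sketch of driving the same two passes from their textual flag names (assumes the usual TableGen-driven registration has run; the function name is ours):

    #include "mlir/Pass/PassManager.h"
    #include "mlir/Pass/PassRegistry.h"

    // Sketch: equivalent to running
    //   torch-mlir-opt -torch-unpack-quant-tensor -torch-convert-custom-quant-op
    mlir::LogicalResult buildPrototypeQuantPipeline(mlir::PassManager &pm) {
      return mlir::parsePassPipeline(
          "func.func(torch-unpack-quant-tensor,torch-convert-custom-quant-op)",
          pm);
    }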
@@ -194,13 +194,13 @@ static bool isValidTorchDtype(Type dtype) {
     if (type.isSignless() && type.getWidth() == 1)
       return true;
     if (type.isSigned()) {
-      for (unsigned width : {8, 16, 32, 64}) {
+      for (unsigned width : {4, 8, 16, 32, 64}) {
         if (type.getWidth() == width)
           return true;
       }
     }
     if (type.isUnsigned()) {
-      return type.getWidth() == 8;
+      return type.getWidth() == 8 || type.getWidth() == 4;
     }
   }
   return false;
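This widening of isValidTorchDtype is what lets si4/ui4 element types appear on Torch value tensors, e.g. the `!torch.vtensor<[8,8],ui4>` literal produced by the unpack test at the end of this commit. As a standalone restatement of the accepted integer widths (a sketch mirroring the hunk above; the real function also handles signless i1, floats, and quantized types):

    // Integer dtype widths accepted after this change. Hypothetical helper,
    // not part of the commit.
    bool isValidTorchIntWidth(bool isSigned, unsigned width) {
      if (isSigned)
        return width == 4 || width == 8 || width == 16 || width == 32 ||
               width == 64;
      // Unsigned torch dtypes remain limited to ui8, plus the new ui4.
      return width == 8 || width == 4;
    }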
@@ -25,6 +25,8 @@ add_mlir_library(TorchMLIRTorchConversionPasses
   BackendTypeConversion.cpp
   BackendTypeConversionPasses.cpp
   Passes.cpp
+  ConvertCustomQuantOp.cpp
+  UnpackQuantTensor.cpp
   VerifyLinalgOnTensorsBackendContract.cpp
   VerifyTosaBackendContract.cpp
   VerifyStablehloBackendContract.cpp
@@ -0,0 +1,225 @@
+//===----------------------------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+// Also available under a BSD-style license. See LICENSE.
+//
+//===----------------------------------------------------------------------===//
+
+#include "PassDetail.h"
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "torch-mlir/Conversion/Utils/Utils.h"
+#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
+#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
+#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h"
+#include "torch-mlir/Dialect/TorchConversion/Transforms/Passes.h"
+#include "torch-mlir/Dialect/TorchConversion/Transforms/BackendTypeConversion.h"
+
+using namespace mlir;
+using namespace mlir::torch;
+using namespace mlir::torch::Torch;
+
+namespace {
+class ConvertCustomQuantizedMatmulOp : public OpConversionPattern<OperatorOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(OperatorOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    if (op.getName().str() != "quant.matmul_rhs_group_quant") {
+      return failure();
+    }
+    Location loc = op->getLoc();
+    if (failed(verifyLinalgCompatibleTypes(op, rewriter))) {
+      return failure();
+    }
+
+    // get inputs: lhs, rhsQuant, scales, zps
+    Value lhs = adaptor.getOperands()[0];
+    auto lhsType = lhs.getType().cast<RankedTensorType>();
+    if (!lhsType) {
+      return failure();
+    }
+    auto lhsShape = lhsType.getShape();
+    int lhsReductDimSize = lhsShape.back();
+
+    Value rhsQuant = adaptor.getOperands()[1];
+    auto rhsType = rhsQuant.getType().cast<RankedTensorType>();
+    if (!rhsType) {
+      return failure();
+    }
+    auto rhsShape = rhsType.getShape();
+    int rhsReductDimSize = rhsShape.back();
+    Type rhsElementType = rhsType.getElementType();
+
+    Value scales = adaptor.getOperands()[2];
+    Value zps = adaptor.getOperands()[3];
+    Value unpackedTypeWidth = adaptor.getOperands()[4];
+    Value groupSize = adaptor.getOperands()[5];
+
+    auto getConstantIntegerFromDefiningOp = [](Value operand,
+                                               int &extractedInt) {
+      auto castOp =
+          dyn_cast<mlir::UnrealizedConversionCastOp>(operand.getDefiningOp());
+      if (!castOp) {
+        return failure();
+      }
+      auto constOp =
+          dyn_cast<Torch::ConstantIntOp>(castOp.getOperand(0).getDefiningOp());
+      if (!constOp) {
+        return failure();
+      }
+      extractedInt = constOp.getValue();
+      return success();
+    };
+
+    int gs;
+    if (failed(getConstantIntegerFromDefiningOp(groupSize, gs))) {
+      return failure();
+    }
+    int unpackedBitWidth;
+    if (failed(getConstantIntegerFromDefiningOp(unpackedTypeWidth,
+                                                unpackedBitWidth))) {
+      return failure();
+    }
+    if (unpackedBitWidth != rhsElementType.getIntOrFloatBitWidth()) {
+      return failure();
+    }
+
+    // get outputs
+    Type newResultType = getTypeConverter()->convertType(op.getType(0));
+    auto resultType = newResultType.cast<RankedTensorType>();
+    if (!resultType) {
+      return failure();
+    }
+    auto resultShape = resultType.getShape();
+    Type elementType = resultType.getElementType();
+
+    // expand lhs
+    std::vector<int64_t> lhsExpandedShape = {lhsShape[0], lhsShape[1],
+                                             lhsReductDimSize / gs, gs};
+    RankedTensorType lhsExpandedType =
+        RankedTensorType::get(lhsExpandedShape, elementType);
+    SmallVector<ReassociationIndices, 4> lhsReassociation = {{0}, {1}, {2, 3}};
+    Value lhsExpanded = rewriter.create<tensor::ExpandShapeOp>(
+        loc, lhsExpandedType, lhs, lhsReassociation);
+
+    // expand rhs
+    std::vector<int64_t> rhsExpandedShape = {rhsShape[0],
+                                             rhsReductDimSize / gs, gs};
+    RankedTensorType rhsExpandedType =
+        RankedTensorType::get(rhsExpandedShape, rhsElementType);
+    SmallVector<ReassociationIndices, 4> rhsReassociation = {{0}, {1, 2}};
+    Value rhsExpanded = rewriter.create<tensor::ExpandShapeOp>(
+        loc, rhsExpandedType, rhsQuant, rhsReassociation);
+    Value cst0 = rewriter.create<arith::ConstantOp>(
+        loc, FloatAttr::get(elementType, 0.0));
+
+    Value emptyDequant = rewriter.create<tensor::EmptyOp>(
+        loc, rhsExpandedShape, elementType);
+    SmallVector<Value> dynDims;
+    for (int i = 0; i < lhsType.getRank(); i++) {
+      if (lhsType.isDynamicDim(i)) {
+        dynDims.push_back(rewriter.create<tensor::DimOp>(loc, lhs, i));
+      }
+    }
+    Value empty = rewriter.create<tensor::EmptyOp>(
+        loc, resultShape, elementType, dynDims);
+    Value output = rewriter.create<linalg::FillOp>(
+        loc, cst0, empty).getResult(0);
+
+    AffineExpr d0, d1, d2, d3, d4;
+    bindDims(getContext(), d0, d1, d2, d3, d4);
+    auto c0 = rewriter.getAffineConstantExpr(0);
+    auto map = AffineMap::get(3, 0, {d0, d1, d2}, rewriter.getContext());
+    auto map1 = AffineMap::get(3, 0, {d0, d1, c0}, rewriter.getContext());
+    auto map2 = AffineMap::get(5, 0, {d0, d1, d3, d4}, rewriter.getContext());
+    auto map3 = AffineMap::get(5, 0, {d2, d3, d4}, rewriter.getContext());
+    auto map4 = AffineMap::get(5, 0, {d0, d1, d2}, rewriter.getContext());
+    SmallVector<AffineMap, 4> dqIndexingMaps = {map, map1, map1, map};
+    SmallVector<AffineMap, 4> matIndexingMaps = {map2, map3, map4};
+
+    SmallVector<utils::IteratorType> dequantIteratorTypes(
+        3, utils::IteratorType::parallel);
+    SmallVector<utils::IteratorType> matmulIteratorTypes = {
+        utils::IteratorType::parallel, utils::IteratorType::parallel,
+        utils::IteratorType::parallel, utils::IteratorType::reduction,
+        utils::IteratorType::reduction
+    };
+
+    Value rhsDequant =
+        rewriter
+            .create<linalg::GenericOp>(
+                loc, emptyDequant.getType(),
+                ValueRange{rhsExpanded, scales, zps}, emptyDequant,
+                /*indexingMaps=*/dqIndexingMaps,
+                /*iteratorTypes=*/dequantIteratorTypes,
+                [&](OpBuilder &b, Location loc, ValueRange args) {
+                  Value w = args[0], scale = args[1], zeroPoint = args[2];
+                  Value extw =
+                      b.create<arith::ExtUIOp>(loc, rewriter.getI32Type(), w);
+                  Value fp_extw = b.create<arith::UIToFPOp>(
+                      loc, rewriter.getF16Type(), extw);
+                  Value shifted =
+                      b.create<arith::SubFOp>(loc, fp_extw, zeroPoint);
+                  Value dqw = b.create<arith::MulFOp>(loc, shifted, scale);
+                  b.create<linalg::YieldOp>(loc, dqw);
+                })
+            .getResult(0);
+
+    Value matmulDequant =
+        rewriter
+            .create<linalg::GenericOp>(
+                loc, output.getType(),
+                ValueRange{lhsExpanded, rhsDequant}, output,
+                /*indexingMaps=*/matIndexingMaps,
+                /*iteratorTypes=*/matmulIteratorTypes,
+                [&](OpBuilder &b, Location loc, ValueRange args) {
+                  Value l = args[0], r = args[1], out = args[2];
+                  Value pd = b.create<arith::MulFOp>(loc, l, r);
+                  Value ac = b.create<arith::AddFOp>(loc, pd, out);
+                  b.create<linalg::YieldOp>(loc, ac);
+                })
+            .getResult(0);
+
+    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, matmulDequant);
+    return success();
+  }
+};
+} // namespace
+
+namespace {
+class ConvertCustomQuantOpPass
+    : public TorchConversion::ConvertCustomQuantOpBase<
+          ConvertCustomQuantOpPass> {
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<arith::ArithDialect>();
+    registry.insert<func::FuncDialect>();
+    registry.insert<linalg::LinalgDialect>();
+    registry.insert<tensor::TensorDialect>();
+    registry.insert<Torch::TorchDialect>();
+    TorchConversion::getBackendTypeConversionDependentDialects(registry);
+  }
+
+  void runOnOperation() override {
+    MLIRContext *context = &getContext();
+    ConversionTarget target(*context);
+    target.addLegalDialect<linalg::LinalgDialect, func::FuncDialect,
+                           tensor::TensorDialect, arith::ArithDialect,
+                           Torch::TorchDialect>();
+
+    TypeConverter typeConverter;
+    typeConverter.addConversion([](Type type) { return type; });
+    TorchConversion::setupBackendTypeConversion(target, typeConverter);
+
+    RewritePatternSet patterns(context);
+    target.addIllegalOp<OperatorOp>();
+    patterns.add<ConvertCustomQuantizedMatmulOp>(typeConverter, context);
+
+    if (failed(applyPartialConversion(getOperation(), target,
+                                      std::move(patterns))))
+      signalPassFailure();
+  }
+};
+} // namespace
+
+std::unique_ptr<OperationPass<func::FuncOp>>
+mlir::torch::TorchConversion::createConvertCustomQuantOpPass() {
+  return std::make_unique<ConvertCustomQuantOpPass>();
+}
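The two linalg.generic ops above compute, per output column and per group of size gs along the reduction dimension, w_dq = (w - zp) * scale, followed by a plain f16 multiply-accumulate against the lhs. A scalar reference of the same computation (an illustrative sketch only: row-major layouts, float instead of f16, and all names are our assumptions, not part of the commit):

    #include <cstdint>

    // out[M][N] = lhs[M][K] x dequant(rhsQuant[N][K]), where scales/zps hold
    // one entry per (row, group), i.e. have shape [N][K/gs].
    void matmulRhsGroupQuantRef(const float *lhs, const uint8_t *rhsQuant,
                                const float *scales, const float *zps,
                                float *out, int M, int N, int K, int gs) {
      for (int i = 0; i < M; ++i) {
        for (int j = 0; j < N; ++j) {
          float acc = 0.0f;
          for (int k = 0; k < K; ++k) {
            int g = k / gs; // quantization group of this rhs element
            float w = (static_cast<float>(rhsQuant[j * K + k]) -
                       zps[j * (K / gs) + g]) *
                      scales[j * (K / gs) + g];
            acc += lhs[i * K + k] * w;
          }
          out[i * N + j] = acc;
        }
      }
    }

The group structure is what the tensor.expand_shape ops make explicit: the reduction dimension K is split into (K/gs, gs) so that the constant-0 affine maps can broadcast one scale and one zero point across each group.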
@@ -0,0 +1,143 @@
+//===----------------------------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+// Also available under a BSD-style license. See LICENSE.
+//
+//===----------------------------------------------------------------------===//
+
+#include "PassDetail.h"
+
+#include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
+#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
+#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h"
+#include "torch-mlir/Dialect/TorchConversion/Transforms/Passes.h"
+
+using namespace mlir;
+using namespace mlir::torch;
+using namespace mlir::torch::Torch;
+
+namespace {
+class UnpackQuantizedMatmulWeights
+    : public OpRewritePattern<ValueTensorLiteralOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(ValueTensorLiteralOp constOp,
+                                PatternRewriter &rewriter) const override {
+    if (!constOp->hasOneUse())
+      return failure();
+
+    OpOperand *use = constOp.getResult().use_begin().getOperand();
+    auto op = dyn_cast<OperatorOp>(use->getOwner());
+    if (!op) {
+      return failure();
+    }
+    if (op.getName().str() != "quant.matmul_rhs_group_quant") {
+      return failure();
+    }
+
+    if (use->getOperandNumber() != 1) {
+      return failure();
+    }
+
+    Value rhs = op.getOperand(1);
+    Value bitWidth = op.getOperand(4);
+
+    auto getConstantIntegerFromDefiningOp = [](Value operand,
+                                               int &extractedInt) {
+      auto constOp = dyn_cast<Torch::ConstantIntOp>(operand.getDefiningOp());
+      if (!constOp) {
+        return failure();
+      }
+      extractedInt = constOp.getValue();
+      return success();
+    };
+    int unpackedBitWidth;
+    if (failed(getConstantIntegerFromDefiningOp(bitWidth, unpackedBitWidth)))
+      return failure();
+
+    auto rhsType = rhs.getType().dyn_cast<ValueTensorType>();
+    if (!rhsType)
+      return failure();
+
+    if (!rhsType.hasDtype())
+      return failure();
+
+    Type dType = rhsType.getDtype();
+    int dTypeWidth = dType.getIntOrFloatBitWidth();
+    if (dTypeWidth == unpackedBitWidth)
+      return failure();
+
+    if (!rhsType.hasSizes())
+      return failure();
+
+    SmallVector<int64_t> tensorShape(rhsType.getSizes());
+    if (tensorShape.back() == kUnknownSize)
+      return failure();
+    int packRatio = dTypeWidth / unpackedBitWidth;
+
+    tensorShape[tensorShape.size() - 1] *= packRatio;
+    Type unpackedElementType;
+    if (dType.isSignedInteger())
+      unpackedElementType = rewriter.getIntegerType(unpackedBitWidth, true);
+    else
+      unpackedElementType = rewriter.getIntegerType(unpackedBitWidth, false);
+    ValueTensorType newRhsType = ValueTensorType::get(
+        rewriter.getContext(), tensorShape, unpackedElementType);
+
+    auto elements = constOp.getValueAttr().dyn_cast<DenseIntElementsAttr>();
+    if (!elements)
+      return failure();
+
+    auto attrType = RankedTensorType::get(tensorShape, unpackedElementType);
+
+    // TODO: Materialize IR that does the conversion from quantized type to
+    // pure integer type, which relies on constant evaluation in backends.
+    auto data = elements.getRawData();
+    std::vector<APInt> newData(data.size() * packRatio,
+                               APInt(unpackedBitWidth, 0));
+    for (int i = 0, e = data.size(); i < e; ++i) {
+      auto el = data[i];
+      char mask = (1 << unpackedBitWidth) - 1;
+      for (int b = 0; b < packRatio; b++) {
+        newData[i * packRatio + b] =
+            APInt(unpackedBitWidth, (el & mask) >> (unpackedBitWidth * b));
+        mask = mask << unpackedBitWidth;
+      }
+    }
+    rewriter.replaceOpWithNewOp<ValueTensorLiteralOp>(
+        constOp, newRhsType,
+        DenseElementsAttr::get(attrType, ArrayRef<APInt>(newData)));
+    return success();
+  }
+};
+} // namespace
+
+namespace {
+class UnpackQuantTensorPass
+    : public TorchConversion::UnpackQuantTensorBase<UnpackQuantTensorPass> {
+  using UnpackQuantTensorBase<UnpackQuantTensorPass>::UnpackQuantTensorBase;
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<func::FuncDialect>();
+    registry.insert<Torch::TorchDialect>();
+  }
+
+  void runOnOperation() override {
+    MLIRContext *context = &getContext();
+    RewritePatternSet patterns(context);
+    patterns.add<UnpackQuantizedMatmulWeights>(context);
+
+    if (failed(
+            applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
+      signalPassFailure();
+  }
+};
+} // namespace
+
+std::unique_ptr<OperationPass<func::FuncOp>>
+mlir::torch::TorchConversion::createUnpackQuantTensorPass() {
+  return std::make_unique<UnpackQuantTensorPass>();
+}
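The inner loop above peels packRatio narrow fields out of each stored byte, lowest bits first, which is why the unpacked row in the test below reads [9, 3, ...] for the packed value 57 (0x39). For the int4-in-int8 case that the test exercises, an equivalent standalone sketch (hypothetical helper, not part of the commit):

    #include <cstdint>
    #include <vector>

    // Unpack two 4-bit values from each byte, low nibble first.
    std::vector<uint8_t> unpackInt4(const std::vector<uint8_t> &packed) {
      std::vector<uint8_t> unpacked;
      unpacked.reserve(packed.size() * 2);
      for (uint8_t byte : packed) {
        unpacked.push_back(byte & 0x0F);        // low nibble first...
        unpacked.push_back((byte >> 4) & 0x0F); // ...then the high nibble
      }
      return unpacked;
    }

For example, unpackInt4({57, 128}) yields {9, 3, 0, 8}, matching the first four elements of the CHECK line in the test below.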
@@ -0,0 +1,45 @@
+// RUN: torch-mlir-opt %s -torch-convert-custom-quant-op -split-input-file -verify-diagnostics | FileCheck %s
+
+// CHECK: #map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// CHECK: #map1 = affine_map<(d0, d1, d2) -> (d0, d1, 0)>
+// CHECK: #map2 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3, d4)>
+// CHECK: #map3 = affine_map<(d0, d1, d2, d3, d4) -> (d2, d3, d4)>
+// CHECK: #map4 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>
+// CHECK-LABEL: func @forward
+func.func @forward(%arg0: !torch.vtensor<[1,1,2],f16>) -> !torch.vtensor<[1,1,2],f16> {
+  %q_rhs = torch.vtensor.literal(dense<[[0, 1], [2, 3]]> : tensor<2x2xui8>) : !torch.vtensor<[2,2],ui8>
+  %scales = torch.vtensor.literal(dense<1.0> : tensor<2x1x1xf16>) : !torch.vtensor<[2,1,1],f16>
+  %zps = torch.vtensor.literal(dense<0.0> : tensor<2x1x1xf16>) : !torch.vtensor<[2,1,1],f16>
+  %bit_width = torch.constant.int 8
+  %group_size = torch.constant.int 2
+  %output = torch.operator "quant.matmul_rhs_group_quant"(%arg0, %q_rhs, %scales, %zps, %bit_width, %group_size) : (!torch.vtensor<[1,1,2],f16>, !torch.vtensor<[2,2],ui8>, !torch.vtensor<[2,1,1],f16>, !torch.vtensor<[2,1,1],f16>, !torch.int, !torch.int) -> !torch.vtensor<[1,1,2],f16>
+  // CHECK: %[[LHS:.*]] = torch_c.to_builtin_tensor %arg0 : !torch.vtensor<[1,1,2],f16> -> tensor<1x1x2xf16>
+  // CHECK: %[[TENSOR1:.*]] = torch.vtensor.literal(dense<{{\[\[}}0, 1], [2, 3]]> : tensor<2x2xui8>) : !torch.vtensor<[2,2],ui8>
+  // CHECK: %[[QUANT_RHS:.*]] = torch_c.to_builtin_tensor %[[TENSOR1]] : !torch.vtensor<[2,2],ui8> -> tensor<2x2xi8>
+  // CHECK: %[[TENSOR2:.*]] = torch.vtensor.literal(dense<1.000000e+00> : tensor<2x1x1xf16>) : !torch.vtensor<[2,1,1],f16>
+  // CHECK: %[[SCALES:.*]] = torch_c.to_builtin_tensor %[[TENSOR2]] : !torch.vtensor<[2,1,1],f16> -> tensor<2x1x1xf16>
+  // CHECK: %[[TENSOR3:.*]] = torch.vtensor.literal(dense<0.000000e+00> : tensor<2x1x1xf16>) : !torch.vtensor<[2,1,1],f16>
+  // CHECK: %[[ZPS:.*]] = torch_c.to_builtin_tensor %[[TENSOR3]] : !torch.vtensor<[2,1,1],f16> -> tensor<2x1x1xf16>
+  // CHECK: %[[EXPANDED_LHS:.*]] = tensor.expand_shape %[[LHS]] {{\[\[}}0], [1], [2, 3]] : tensor<1x1x2xf16> into tensor<1x1x1x2xf16>
+  // CHECK: %[[EXPANDED_RHS:.*]] = tensor.expand_shape %[[QUANT_RHS]] {{\[\[}}0], [1, 2]] : tensor<2x2xi8> into tensor<2x1x2xi8>
+  // CHECK: %[[CST:.*]] = arith.constant 0.000000e+00 : f16
+  // CHECK: %[[EMPTY1:.*]] = tensor.empty() : tensor<2x1x2xf16>
+  // CHECK: %[[EMPTY2:.*]] = tensor.empty() : tensor<1x1x2xf16>
+  // CHECK: %[[OUT:.*]] = linalg.fill ins(%[[CST]] : f16) outs(%[[EMPTY2]] : tensor<1x1x2xf16>) -> tensor<1x1x2xf16>
+  // CHECK: %[[DEQUANT_RHS:.*]] = linalg.generic {indexing_maps = [#map, #map1, #map1, #map], iterator_types = ["parallel", "parallel", "parallel"]} ins(%[[EXPANDED_RHS]], %[[SCALES]], %[[ZPS]] : tensor<2x1x2xi8>, tensor<2x1x1xf16>, tensor<2x1x1xf16>) outs(%[[EMPTY1]] : tensor<2x1x2xf16>) {
+  // CHECK-NEXT: ^bb0(%[[WEIGHTS:.*]]: i8, %[[SCALES:.*]]: f16, %[[ZPS:.*]]: f16, %{{.*}}: f16):
+  // CHECK-NEXT: %[[EXTUI:.*]] = arith.extui %[[WEIGHTS]] : i8 to i32
+  // CHECK-NEXT: %[[UITOFP:.*]] = arith.uitofp %[[EXTUI]] : i32 to f16
+  // CHECK-NEXT: %[[SUBF:.*]] = arith.subf %[[UITOFP]], %[[ZPS]] : f16
+  // CHECK-NEXT: %[[MULF:.*]] = arith.mulf %[[SUBF]], %[[SCALES]] : f16
+  // CHECK-NEXT: linalg.yield %[[MULF]] : f16
+  // CHECK-NEXT: } -> tensor<2x1x2xf16>
+  // CHECK: %[[MATMUL:.*]] = linalg.generic {indexing_maps = [#map2, #map3, #map4], iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction"]} ins(%[[EXPANDED_LHS]], %[[DEQUANT_RHS]] : tensor<1x1x1x2xf16>, tensor<2x1x2xf16>) outs(%[[OUT]] : tensor<1x1x2xf16>) {
+  // CHECK-NEXT: ^bb0(%[[LHS:.*]]: f16, %[[RHS:.*]]: f16, %[[OUT:.*]]: f16):
+  // CHECK-NEXT: %[[MULF:.*]] = arith.mulf %[[LHS]], %[[RHS]] : f16
+  // CHECK-NEXT: %[[ADDF:.*]] = arith.addf %[[MULF]], %[[OUT]] : f16
+  // CHECK-NEXT: linalg.yield %[[ADDF]] : f16
+  // CHECK-NEXT: } -> tensor<1x1x2xf16>
+  // CHECK: %[[CASTED:.*]] = tensor.cast %[[MATMUL]] : tensor<1x1x2xf16> to tensor<1x1x2xf16>
+  return %output : !torch.vtensor<[1,1,2],f16>
+}
@@ -0,0 +1,13 @@
+// RUN: torch-mlir-opt %s -torch-unpack-quant-tensor -split-input-file -verify-diagnostics | FileCheck %s
+
+// CHECK-LABEL: func @forward
+func.func @forward(%arg0: !torch.vtensor<[1,1,8],f16>) -> !torch.vtensor<[1,1,8],f16> {
+  %q_rhs = torch.vtensor.literal(dense<[[57, 128, 249, 244], [7, 243, 27, 15], [1, 2, 159, 71], [159, 253, 160, 231], [248, 224, 191, 228], [96, 15, 158, 220], [240, 250, 47, 208], [127, 192, 239, 176]]> : tensor<8x4xui8>) : !torch.vtensor<[8,4],ui8>
+  // CHECK: %[[C0:.*]] = torch.vtensor.literal(dense<{{\[\[}}9, 3, 0, 8, 9, 15, 4, 15], [7, 0, 3, 15, 11, 1, 15, 0], [1, 0, 2, 0, 15, 9, 7, 4], [15, 9, 13, 15, 0, 10, 7, 14], [8, 15, 0, 14, 15, 11, 4, 14], [0, 6, 15, 0, 14, 9, 12, 13], [0, 15, 10, 15, 15, 2, 0, 13], [15, 7, 0, 12, 15, 14, 0, 11]]> : tensor<8x8xui4>) : !torch.vtensor<[8,8],ui4>
+  %scales = torch.vtensor.literal(dense<1.0> : tensor<8x4x1xf16>) : !torch.vtensor<[8,4,1],f16>
+  %zps = torch.vtensor.literal(dense<0.0> : tensor<8x4x1xf16>) : !torch.vtensor<[8,4,1],f16>
+  %bit_width = torch.constant.int 4
+  %group_size = torch.constant.int 2
+  %output = torch.operator "quant.matmul_rhs_group_quant"(%arg0, %q_rhs, %scales, %zps, %bit_width, %group_size) : (!torch.vtensor<[1,1,8],f16>, !torch.vtensor<[8,4],ui8>, !torch.vtensor<[8,4,1],f16>, !torch.vtensor<[8,4,1],f16>, !torch.int, !torch.int) -> !torch.vtensor<[1,1,8],f16>
+  return %output : !torch.vtensor<[1,1,8],f16>
+}