From 1682b540bfe37d3cadab582cda56cb723467b1f1 Mon Sep 17 00:00:00 2001
From: jinchen62 <49575973+jinchen62@users.noreply.github.com>
Date: Tue, 29 Aug 2023 21:25:45 -0700
Subject: [PATCH] Prototype passes for lowering quantized group matmul (#2402)

* Support brevitas custom op (#2320)

* f16 change for brevitas

* Adapt the change of brevitas quant custom op name

* Add unit tests

* Make brevitas conversions isolated

* Address the comments

---------

Co-authored-by: dan
---
 .../TorchConversion/Transforms/Passes.h       |   8 +
 .../TorchConversion/Transforms/Passes.td      |  12 +
 lib/Dialect/Torch/IR/TorchTypes.cpp           |   4 +-
 .../TorchConversion/Transforms/CMakeLists.txt |   2 +
 .../Transforms/ConvertCustomQuantOp.cpp       | 225 ++++++++++++++++++
 .../Transforms/UnpackQuantTensor.cpp          | 143 +++++++++++
 .../convert-custom-quant-op.mlir              |  45 ++++
 .../TorchConversion/unpack-quant-tensor.mlir  |  13 +
 8 files changed, 450 insertions(+), 2 deletions(-)
 create mode 100644 lib/Dialect/TorchConversion/Transforms/ConvertCustomQuantOp.cpp
 create mode 100644 lib/Dialect/TorchConversion/Transforms/UnpackQuantTensor.cpp
 create mode 100644 test/Dialect/TorchConversion/convert-custom-quant-op.mlir
 create mode 100644 test/Dialect/TorchConversion/unpack-quant-tensor.mlir

diff --git a/include/torch-mlir/Dialect/TorchConversion/Transforms/Passes.h b/include/torch-mlir/Dialect/TorchConversion/Transforms/Passes.h
index e6493a154..d762bd840 100644
--- a/include/torch-mlir/Dialect/TorchConversion/Transforms/Passes.h
+++ b/include/torch-mlir/Dialect/TorchConversion/Transforms/Passes.h
@@ -57,6 +57,14 @@ std::unique_ptr<OperationPass<ModuleOp>> createFuncBackendTypeConversionPass();
 std::unique_ptr<OperationPass<func::FuncOp>>
 createFinalizingBackendTypeConversionPass();
 
+// These passes do a one-off conversion of a specific kind of quantized group
+// matmul as a prototype. Generalized quantized operation handling will likely
+// obviate them, but they are carried for now to unblock progress on full
+// integrations. See https://github.com/llvm/torch-mlir/issues/2417 for the
+// plan to support a more generalized lowering for these graphs.
+std::unique_ptr<OperationPass<func::FuncOp>> createUnpackQuantTensorPass();
+std::unique_ptr<OperationPass<func::FuncOp>> createConvertCustomQuantOpPass();
+
 std::unique_ptr<OperationPass<ModuleOp>>
 createVerifyLinalgOnTensorsBackendContractPass();

diff --git a/include/torch-mlir/Dialect/TorchConversion/Transforms/Passes.td b/include/torch-mlir/Dialect/TorchConversion/Transforms/Passes.td
index cb58dbbd9..4d3e16a81 100644
--- a/include/torch-mlir/Dialect/TorchConversion/Transforms/Passes.td
+++ b/include/torch-mlir/Dialect/TorchConversion/Transforms/Passes.td
@@ -48,4 +48,16 @@ def VerifyStablehloBackendContract : Pass<"torch-verify-stablehlo-backend-contract", "ModuleOp"> {
   let constructor = "mlir::torch::TorchConversion::createVerifyStablehloBackendContractPass()";
 }
 #endif // TORCH_MLIR_ENABLE_STABLEHLO
+
+// The following passes are for a one-off conversion of a specific kind of quantized group matmul.
+// They should not be included in default lowering flows until they are further along.
+def UnpackQuantTensor : Pass<"torch-unpack-quant-tensor", "func::FuncOp"> {
+  let summary = "Unpack quantized int4 tensor from int8 container";
+  let constructor = "mlir::torch::TorchConversion::createUnpackQuantTensorPass()";
+}
+
+def ConvertCustomQuantOp : Pass<"torch-convert-custom-quant-op", "func::FuncOp"> {
+  let summary = "Convert torch custom quant op to linalg";
+  let constructor = "mlir::torch::TorchConversion::createConvertCustomQuantOpPass()";
+}
 #endif // TORCHMLIR_TORCHCONVERSION_PASSES

diff --git a/lib/Dialect/Torch/IR/TorchTypes.cpp b/lib/Dialect/Torch/IR/TorchTypes.cpp
index 8eb844cbd..1c8d3c6f7 100644
--- a/lib/Dialect/Torch/IR/TorchTypes.cpp
+++ b/lib/Dialect/Torch/IR/TorchTypes.cpp
@@ -194,13 +194,13 @@ static bool isValidTorchDtype(Type dtype) {
     if (type.isSignless() && type.getWidth() == 1)
       return true;
     if (type.isSigned()) {
-      for (unsigned width : {8, 16, 32, 64}) {
+      for (unsigned width : {4, 8, 16, 32, 64}) {
         if (type.getWidth() == width)
           return true;
       }
     }
     if (type.isUnsigned()) {
-      return type.getWidth() == 8;
+      return type.getWidth() == 8 || type.getWidth() == 4;
     }
   }
   return false;

diff --git a/lib/Dialect/TorchConversion/Transforms/CMakeLists.txt b/lib/Dialect/TorchConversion/Transforms/CMakeLists.txt
index 1f7f4e8f8..6495e4682 100644
--- a/lib/Dialect/TorchConversion/Transforms/CMakeLists.txt
+++ b/lib/Dialect/TorchConversion/Transforms/CMakeLists.txt
@@ -25,6 +25,8 @@ add_mlir_library(TorchMLIRTorchConversionPasses
   BackendTypeConversion.cpp
   BackendTypeConversionPasses.cpp
   Passes.cpp
+  ConvertCustomQuantOp.cpp
+  UnpackQuantTensor.cpp
   VerifyLinalgOnTensorsBackendContract.cpp
   VerifyTosaBackendContract.cpp
   VerifyStablehloBackendContract.cpp

diff --git a/lib/Dialect/TorchConversion/Transforms/ConvertCustomQuantOp.cpp b/lib/Dialect/TorchConversion/Transforms/ConvertCustomQuantOp.cpp
new file mode 100644
index 000000000..f6432c602
--- /dev/null
+++ b/lib/Dialect/TorchConversion/Transforms/ConvertCustomQuantOp.cpp
@@ -0,0 +1,225 @@
+//===----------------------------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+// Also available under a BSD-style license. See LICENSE.
+//
+//===----------------------------------------------------------------------===//
+
+#include "PassDetail.h"
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "torch-mlir/Conversion/Utils/Utils.h"
+#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
+#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
+#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h"
+#include "torch-mlir/Dialect/TorchConversion/Transforms/Passes.h"
+#include "torch-mlir/Dialect/TorchConversion/Transforms/BackendTypeConversion.h"
+
+using namespace mlir;
+using namespace mlir::torch;
+using namespace mlir::torch::Torch;
+
+namespace {
+class ConvertCustomQuantizedMatmulOp : public OpConversionPattern<OperatorOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(OperatorOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    if (op.getName().str() != "quant.matmul_rhs_group_quant") {
+      return failure();
+    }
+    Location loc = op->getLoc();
+    if (failed(verifyLinalgCompatibleTypes(op, rewriter))) {
+      return failure();
+    }
+
+    // Get inputs: lhs, rhsQuant, scales, zps.
+    Value lhs = adaptor.getOperands()[0];
+    auto lhsType = lhs.getType().cast<RankedTensorType>();
+    if (!lhsType) {
+      return failure();
+    }
+    auto lhsShape = lhsType.getShape();
+    int lhsReductDimSize = lhsShape.back();
+
+    Value rhsQuant = adaptor.getOperands()[1];
+    auto rhsType = rhsQuant.getType().cast<RankedTensorType>();
+    if (!rhsType) {
+      return failure();
+    }
+    auto rhsShape = rhsType.getShape();
+    int rhsReductDimSize = rhsShape.back();
+    Type rhsElementType = rhsType.getElementType();
+
+    Value scales = adaptor.getOperands()[2];
+    Value zps = adaptor.getOperands()[3];
+    Value unpackedTypeWidth = adaptor.getOperands()[4];
+    Value groupSize = adaptor.getOperands()[5];
+
+    auto getConstantIntegerFromDefiningOp = [](Value operand,
+                                               int &extractedInt) {
+      auto castOp =
+          dyn_cast<TorchConversion::ToI64Op>(operand.getDefiningOp());
+      if (!castOp) {
+        return failure();
+      }
+      auto constOp =
+          dyn_cast<Torch::ConstantIntOp>(castOp.getOperand(0).getDefiningOp());
+      if (!constOp) {
+        return failure();
+      }
+      extractedInt = constOp.getValue();
+      return success();
+    };
+
+    int gs;
+    if (failed(getConstantIntegerFromDefiningOp(groupSize, gs))) {
+      return failure();
+    }
+    int unpackedBitWidth;
+    if (failed(getConstantIntegerFromDefiningOp(unpackedTypeWidth,
+                                                unpackedBitWidth))) {
+      return failure();
+    }
+    if (unpackedBitWidth != rhsElementType.getIntOrFloatBitWidth()) {
+      return failure();
+    }
+
+    // Get outputs.
+    Type newResultType = getTypeConverter()->convertType(op.getType(0));
+    auto resultType = newResultType.cast<RankedTensorType>();
+    if (!resultType) {
+      return failure();
+    }
+    auto resultShape = resultType.getShape();
+    Type elementType = resultType.getElementType();
+
+    // Expand lhs so the reduction dimension splits into (numGroups, gs).
+    std::vector<int64_t> lhsExpandedShape = {lhsShape[0], lhsShape[1],
+                                             lhsReductDimSize / gs, gs};
+    RankedTensorType lhsExpandedType =
+        RankedTensorType::get(lhsExpandedShape, elementType);
+    SmallVector<ReassociationIndices> lhsReassociation = {{0}, {1}, {2, 3}};
+    Value lhsExpanded = rewriter.create<tensor::ExpandShapeOp>(
+        loc, lhsExpandedType, lhs, lhsReassociation);
+
+    // Expand rhs the same way.
+    std::vector<int64_t> rhsExpandedShape = {rhsShape[0],
+                                             rhsReductDimSize / gs, gs};
+    RankedTensorType rhsExpandedType =
+        RankedTensorType::get(rhsExpandedShape, rhsElementType);
+    SmallVector<ReassociationIndices> rhsReassociation = {{0}, {1, 2}};
+    Value rhsExpanded = rewriter.create<tensor::ExpandShapeOp>(
+        loc, rhsExpandedType, rhsQuant, rhsReassociation);
+    Value cst0 = rewriter.create<arith::ConstantOp>(
+        loc, FloatAttr::get(elementType, 0.0));
+
+    Value emptyDequant = rewriter.create<tensor::EmptyOp>(
+        loc, rhsExpandedShape, elementType);
+    SmallVector<Value> dynDims;
+    for (int i = 0; i < lhsType.getRank(); i++) {
+      if (lhsType.isDynamicDim(i)) {
+        dynDims.push_back(rewriter.create<tensor::DimOp>(loc, lhs, i));
+      }
+    }
+    Value empty = rewriter.create<tensor::EmptyOp>(
+        loc, resultShape, elementType, dynDims);
+    Value output = rewriter.create<linalg::FillOp>(
+        loc, cst0, empty).getResult(0);
+
+    // Maps for the 3-d dequantization generic (row, group, element-in-group;
+    // scales and zero points broadcast along the last dimension) and the 5-d
+    // grouped matmul generic, which reduces over the group (d3) and
+    // in-group (d4) dimensions.
+    AffineExpr d0, d1, d2, d3, d4;
+    bindDims(getContext(), d0, d1, d2, d3, d4);
+    auto c0 = rewriter.getAffineConstantExpr(0);
+    auto map = AffineMap::get(3, 0, {d0, d1, d2}, rewriter.getContext());
+    auto map1 = AffineMap::get(3, 0, {d0, d1, c0}, rewriter.getContext());
+    auto map2 = AffineMap::get(5, 0, {d0, d1, d3, d4}, rewriter.getContext());
+    auto map3 = AffineMap::get(5, 0, {d2, d3, d4}, rewriter.getContext());
+    auto map4 = AffineMap::get(5, 0, {d0, d1, d2}, rewriter.getContext());
+    SmallVector<AffineMap> dqIndexingMaps = {map, map1, map1, map};
+    SmallVector<AffineMap> matIndexingMaps = {map2, map3, map4};
+
+    SmallVector<utils::IteratorType> dequantIteratorTypes(
+        3, utils::IteratorType::parallel);
+    SmallVector<utils::IteratorType> matmulIteratorTypes = {
+        utils::IteratorType::parallel, utils::IteratorType::parallel,
+        utils::IteratorType::parallel, utils::IteratorType::reduction,
+        utils::IteratorType::reduction
+    };
+
+    Value rhsDequant =
+        rewriter
+            .create<linalg::GenericOp>(
+                loc, emptyDequant.getType(),
+                ValueRange{rhsExpanded, scales, zps}, emptyDequant,
+                /*indexingMaps=*/dqIndexingMaps,
+                /*iteratorTypes=*/dequantIteratorTypes,
+                [&](OpBuilder &b, Location loc, ValueRange args) {
+                  Value w = args[0], scale = args[1], zeroPoint = args[2];
+                  Value extw =
+                      b.create<arith::ExtUIOp>(loc, rewriter.getI32Type(), w);
+                  Value fp_extw = b.create<arith::UIToFPOp>(
+                      loc, rewriter.getF16Type(), extw);
+                  Value shifted =
+                      b.create<arith::SubFOp>(loc, fp_extw, zeroPoint);
+                  Value dqw = b.create<arith::MulFOp>(loc, shifted, scale);
+                  b.create<linalg::YieldOp>(loc, dqw);
+                })
+            .getResult(0);
+
+    Value matmulDequant =
+        rewriter
+            .create<linalg::GenericOp>(
+                loc, output.getType(),
+                ValueRange{lhsExpanded, rhsDequant}, output,
+                /*indexingMaps=*/matIndexingMaps,
+                /*iteratorTypes=*/matmulIteratorTypes,
+                [&](OpBuilder &b, Location loc, ValueRange args) {
+                  Value l = args[0], r = args[1], out = args[2];
+                  Value pd = b.create<arith::MulFOp>(loc, l, r);
+                  Value ac = b.create<arith::AddFOp>(loc, pd, out);
+                  b.create<linalg::YieldOp>(loc, ac);
+                })
+            .getResult(0);
+
+    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, matmulDequant);
+    return success();
+  }
+};
+} // namespace
+
+namespace {
+class ConvertCustomQuantOpPass
+    : public TorchConversion::ConvertCustomQuantOpBase<
+          ConvertCustomQuantOpPass> {
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<arith::ArithDialect>();
+    registry.insert<func::FuncDialect>();
+    registry.insert<linalg::LinalgDialect>();
+    registry.insert<tensor::TensorDialect>();
+    registry.insert<Torch::TorchDialect>();
+    TorchConversion::getBackendTypeConversionDependentDialects(registry);
+  }
+
+  void runOnOperation() override {
+    MLIRContext *context = &getContext();
+    ConversionTarget target(*context);
+    target.addLegalDialect<arith::ArithDialect, func::FuncDialect,
+                           linalg::LinalgDialect, tensor::TensorDialect,
+                           Torch::TorchDialect>();
+
+    TypeConverter typeConverter;
+    typeConverter.addConversion([](Type type) { return type; });
+    TorchConversion::setupBackendTypeConversion(target, typeConverter);
+
+    RewritePatternSet patterns(context);
+    target.addIllegalOp<OperatorOp>();
+    patterns.add<ConvertCustomQuantizedMatmulOp>(typeConverter, context);
+
+    if (failed(applyPartialConversion(getOperation(), target,
+                                      std::move(patterns))))
+      signalPassFailure();
+  }
+};
+} // namespace
+
+std::unique_ptr<OperationPass<func::FuncOp>>
+mlir::torch::TorchConversion::createConvertCustomQuantOpPass() {
+  return std::make_unique<ConvertCustomQuantOpPass>();
+}
diff --git a/lib/Dialect/TorchConversion/Transforms/UnpackQuantTensor.cpp b/lib/Dialect/TorchConversion/Transforms/UnpackQuantTensor.cpp
new file mode 100644
index 000000000..25f325399
--- /dev/null
+++ b/lib/Dialect/TorchConversion/Transforms/UnpackQuantTensor.cpp
@@ -0,0 +1,143 @@
+//===----------------------------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+// Also available under a BSD-style license. See LICENSE.
+//
+//===----------------------------------------------------------------------===//
+
+#include "PassDetail.h"
+
+#include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
+#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
+#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h"
+#include "torch-mlir/Dialect/TorchConversion/Transforms/Passes.h"
+
+using namespace mlir;
+using namespace mlir::torch;
+using namespace mlir::torch::Torch;
+
+namespace {
+class UnpackQuantizedMatmulWeights
+    : public OpRewritePattern<ValueTensorLiteralOp> {
+public:
+  using OpRewritePattern<ValueTensorLiteralOp>::OpRewritePattern;
+  LogicalResult matchAndRewrite(ValueTensorLiteralOp constOp,
+                                PatternRewriter &rewriter) const override {
+    if (!constOp->hasOneUse())
+      return failure();
+
+    OpOperand *use = constOp.getResult().use_begin().getOperand();
+    auto op = dyn_cast<OperatorOp>(use->getOwner());
+    if (!op) {
+      return failure();
+    }
+    if (op.getName().str() != "quant.matmul_rhs_group_quant") {
+      return failure();
+    }
+
+    // Only unpack the rhs weight operand.
+    if (use->getOperandNumber() != 1) {
+      return failure();
+    }
+
+    Value rhs = op.getOperand(1);
+    Value bitWidth = op.getOperand(4);
+
+    auto getConstantIntegerFromDefiningOp = [](Value operand,
+                                               int &extractedInt) {
+      auto constOp = dyn_cast<Torch::ConstantIntOp>(operand.getDefiningOp());
+      if (!constOp) {
+        return failure();
+      }
+      extractedInt = constOp.getValue();
+      return success();
+    };
+    int unpackedBitWidth;
+    if (failed(getConstantIntegerFromDefiningOp(bitWidth, unpackedBitWidth)))
+      return failure();
+
+    auto rhsType = rhs.getType().dyn_cast<ValueTensorType>();
+    if (!rhsType)
+      return failure();
+
+    if (!rhsType.hasDtype())
+      return failure();
+
+    Type dType = rhsType.getDtype();
+    int dTypeWidth = dType.getIntOrFloatBitWidth();
+    if (dTypeWidth == unpackedBitWidth)
+      return failure();
+
+    if (!rhsType.hasSizes())
+      return failure();
+
+    SmallVector<int64_t> tensorShape(rhsType.getSizes());
+    if (tensorShape.back() == kUnknownSize)
+      return failure();
+    int packRatio = dTypeWidth / unpackedBitWidth;
+
+    tensorShape[tensorShape.size() - 1] *= packRatio;
+    Type unpackedElementType;
+    if (dType.isSignedInteger())
+      unpackedElementType = rewriter.getIntegerType(unpackedBitWidth, true);
+    else
+      unpackedElementType = rewriter.getIntegerType(unpackedBitWidth, false);
+    ValueTensorType newRhsType = ValueTensorType::get(
+        rewriter.getContext(), tensorShape, unpackedElementType);
+
+    auto elements = constOp.getValueAttr().dyn_cast<DenseIntElementsAttr>();
+    if (!elements)
+      return failure();
+
+    auto attrType = RankedTensorType::get(tensorShape, unpackedElementType);
+
+    // TODO: Materialize IR that does the conversion from quantized type to
+    // pure integer type, which relies on constant evaluation in backends.
+    auto data = elements.getRawData();
+    std::vector<APInt> newData(data.size() * packRatio,
+                               APInt(unpackedBitWidth, 0));
+    for (int i = 0, e = data.size(); i < e; ++i) {
+      auto el = data[i];
+      char mask = (1 << unpackedBitWidth) - 1;
+      // Split each stored element into packRatio narrow values, low bits
+      // first.
+      for (int b = 0; b < packRatio; b++) {
+        newData[i * packRatio + b] =
+            APInt(unpackedBitWidth, (el & mask) >> (unpackedBitWidth * b));
+        mask = mask << unpackedBitWidth;
+      }
+    }
+    rewriter.replaceOpWithNewOp<ValueTensorLiteralOp>(
+        constOp, newRhsType,
+        DenseElementsAttr::get(attrType, ArrayRef(newData)));
+    return success();
+  }
+};
+} // namespace
+
+namespace {
+class UnpackQuantTensorPass
+    : public TorchConversion::UnpackQuantTensorBase<UnpackQuantTensorPass> {
+  using UnpackQuantTensorBase::UnpackQuantTensorBase;
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<func::FuncDialect>();
+    registry.insert<Torch::TorchDialect>();
+  }
+
+  void runOnOperation() override {
+    MLIRContext *context = &getContext();
+    RewritePatternSet patterns(context);
+    patterns.add<UnpackQuantizedMatmulWeights>(context);
+
+    if (failed(applyPatternsAndFoldGreedily(getOperation(),
+                                            std::move(patterns))))
+      signalPassFailure();
+  }
+};
+} // namespace
+
+std::unique_ptr<OperationPass<func::FuncOp>>
+mlir::torch::TorchConversion::createUnpackQuantTensorPass() {
  return std::make_unique<UnpackQuantTensorPass>();
+}
diff --git a/test/Dialect/TorchConversion/convert-custom-quant-op.mlir b/test/Dialect/TorchConversion/convert-custom-quant-op.mlir
new file mode 100644
index 000000000..4f72f24e8
--- /dev/null
+++ b/test/Dialect/TorchConversion/convert-custom-quant-op.mlir
@@ -0,0 +1,45 @@
+// RUN: torch-mlir-opt %s -torch-convert-custom-quant-op -split-input-file -verify-diagnostics | FileCheck %s
+
+// CHECK: #map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// CHECK: #map1 = affine_map<(d0, d1, d2) -> (d0, d1, 0)>
+// CHECK: #map2 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3, d4)>
+// CHECK: #map3 = affine_map<(d0, d1, d2, d3, d4) -> (d2, d3, d4)>
+// CHECK: #map4 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>
+// CHECK-LABEL: func @forward
+func.func @forward(%arg0: !torch.vtensor<[1,1,2],f16>) -> !torch.vtensor<[1,1,2],f16> {
+  %q_rhs = torch.vtensor.literal(dense<[[0, 1], [2, 3]]> : tensor<2x2xui8>) : !torch.vtensor<[2,2],ui8>
+  %scales = torch.vtensor.literal(dense<1.0> : tensor<2x1x1xf16>) : !torch.vtensor<[2,1,1],f16>
+  %zps = torch.vtensor.literal(dense<0.0> : tensor<2x1x1xf16>) : !torch.vtensor<[2,1,1],f16>
+  %bit_width = torch.constant.int 8
+  %group_size = torch.constant.int 2
+  %output = torch.operator "quant.matmul_rhs_group_quant"(%arg0, %q_rhs, %scales, %zps, %bit_width, %group_size) : (!torch.vtensor<[1,1,2],f16>, !torch.vtensor<[2,2],ui8>, !torch.vtensor<[2,1,1],f16>, !torch.vtensor<[2,1,1],f16>, !torch.int, !torch.int) -> !torch.vtensor<[1,1,2],f16>
+  // CHECK: %[[LHS:.*]] = torch_c.to_builtin_tensor %arg0 : !torch.vtensor<[1,1,2],f16> -> tensor<1x1x2xf16>
+  // CHECK: %[[TENSOR1:.*]] = torch.vtensor.literal(dense<{{\[\[}}0, 1], [2, 3]]> : tensor<2x2xui8>) : !torch.vtensor<[2,2],ui8>
+  // CHECK: %[[QUANT_RHS:.*]] = torch_c.to_builtin_tensor %[[TENSOR1]] : !torch.vtensor<[2,2],ui8> -> tensor<2x2xi8>
+  // CHECK: %[[TENSOR2:.*]] = torch.vtensor.literal(dense<1.000000e+00> : tensor<2x1x1xf16>) : !torch.vtensor<[2,1,1],f16>
+  // CHECK: %[[SCALES:.*]] = torch_c.to_builtin_tensor %[[TENSOR2]] : !torch.vtensor<[2,1,1],f16> -> tensor<2x1x1xf16>
+  // CHECK: %[[TENSOR3:.*]] = torch.vtensor.literal(dense<0.000000e+00> : tensor<2x1x1xf16>) : !torch.vtensor<[2,1,1],f16>
+  // CHECK: %[[ZPS:.*]] = torch_c.to_builtin_tensor %[[TENSOR3]] : !torch.vtensor<[2,1,1],f16> -> tensor<2x1x1xf16>
+  // CHECK: %[[EXPANDED_LHS:.*]] = tensor.expand_shape %[[LHS]] {{\[\[}}0], [1], [2, 3]] : tensor<1x1x2xf16> into tensor<1x1x1x2xf16>
+  // CHECK: %[[EXPANDED_RHS:.*]] = tensor.expand_shape %[[QUANT_RHS]] {{\[\[}}0], [1, 2]] : tensor<2x2xi8> into tensor<2x1x2xi8>
+  // CHECK: %[[CST:.*]] = arith.constant 0.000000e+00 : f16
+  // CHECK: %[[EMPTY1:.*]] = tensor.empty() : tensor<2x1x2xf16>
+  // CHECK: %[[EMPTY2:.*]] = tensor.empty() : tensor<1x1x2xf16>
+  // CHECK: %[[OUT:.*]] = linalg.fill ins(%[[CST]] : f16) outs(%[[EMPTY2]] : tensor<1x1x2xf16>) -> tensor<1x1x2xf16>
+  // CHECK: %[[DEQUANT_RHS:.*]] = linalg.generic {indexing_maps = [#map, #map1, #map1, #map], iterator_types = ["parallel", "parallel", "parallel"]} ins(%[[EXPANDED_RHS]], %[[SCALES]], %[[ZPS]] : tensor<2x1x2xi8>, tensor<2x1x1xf16>, tensor<2x1x1xf16>) outs(%[[EMPTY1]] : tensor<2x1x2xf16>) {
+  // CHECK-NEXT: ^bb0(%[[WEIGHTS:.*]]: i8, %[[SCALES:.*]]: f16, %[[ZPS:.*]]: f16, %{{.*}}: f16):
+  // CHECK-NEXT: %[[EXTUI:.*]] = arith.extui %[[WEIGHTS]] : i8 to i32
+  // CHECK-NEXT: %[[UITOFP:.*]] = arith.uitofp %[[EXTUI]] : i32 to f16
+  // CHECK-NEXT: %[[SUBF:.*]] = arith.subf %[[UITOFP]], %[[ZPS]] : f16
+  // CHECK-NEXT: %[[MULF:.*]] = arith.mulf %[[SUBF]], %[[SCALES]] : f16
+  // CHECK-NEXT: linalg.yield %[[MULF]] : f16
+  // CHECK-NEXT: } -> tensor<2x1x2xf16>
+  // CHECK: %[[MATMUL:.*]] = linalg.generic {indexing_maps = [#map2, #map3, #map4], iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction"]} ins(%[[EXPANDED_LHS]], %[[DEQUANT_RHS]] : tensor<1x1x1x2xf16>, tensor<2x1x2xf16>) outs(%[[OUT]] : tensor<1x1x2xf16>) {
+  // CHECK-NEXT: ^bb0(%[[LHS:.*]]: f16, %[[RHS:.*]]: f16, %[[OUT:.*]]: f16):
+  // CHECK-NEXT: %[[MULF:.*]] = arith.mulf %[[LHS]], %[[RHS]] : f16
+  // CHECK-NEXT: %[[ADDF:.*]] = arith.addf %[[MULF]], %[[OUT]] : f16
+  // CHECK-NEXT: linalg.yield %[[ADDF]] : f16
+  // CHECK-NEXT: } -> tensor<1x1x2xf16>
+  // CHECK: %[[CASTED:.*]] = tensor.cast %[[MATMUL]] : tensor<1x1x2xf16> to tensor<1x1x2xf16>
+  return %output : !torch.vtensor<[1,1,2],f16>
+}
diff --git a/test/Dialect/TorchConversion/unpack-quant-tensor.mlir b/test/Dialect/TorchConversion/unpack-quant-tensor.mlir
new file mode 100644
index 000000000..0ca64ae09
--- /dev/null
+++ b/test/Dialect/TorchConversion/unpack-quant-tensor.mlir
@@ -0,0 +1,13 @@
+// RUN: torch-mlir-opt %s -torch-unpack-quant-tensor -split-input-file -verify-diagnostics | FileCheck %s
+
+// CHECK-LABEL: func @forward
+func.func @forward(%arg0: !torch.vtensor<[1,1,8],f16>) -> !torch.vtensor<[1,1,8],f16> {
+  %q_rhs = torch.vtensor.literal(dense<[[57, 128, 249, 244], [7, 243, 27, 15], [1, 2, 159, 71], [159, 253, 160, 231], [248, 224, 191, 228], [96, 15, 158, 220], [240, 250, 47, 208], [127, 192, 239, 176]]> : tensor<8x4xui8>) : !torch.vtensor<[8,4],ui8>
+  // CHECK: %[[C0:.*]] = torch.vtensor.literal(dense<{{\[\[}}9, 3, 0, 8, 9, 15, 4, 15], [7, 0, 3, 15, 11, 1, 15, 0], [1, 0, 2, 0, 15, 9, 7, 4], [15, 9, 13, 15, 0, 10, 7, 14], [8, 15, 0, 14, 15, 11, 4, 14], [0, 6, 15, 0, 14, 9, 12, 13], [0, 15, 10, 15, 15, 2, 0, 13], [15, 7, 0, 12, 15, 14, 0, 11]]> : tensor<8x8xui4>) : !torch.vtensor<[8,8],ui4>
+  %scales = torch.vtensor.literal(dense<1.0> : tensor<8x4x1xf16>) : !torch.vtensor<[8,4,1],f16>
+  %zps = torch.vtensor.literal(dense<0.0> : tensor<8x4x1xf16>) : !torch.vtensor<[8,4,1],f16>
+  %bit_width = torch.constant.int 4
+  %group_size = torch.constant.int 2
+  %output = torch.operator "quant.matmul_rhs_group_quant"(%arg0, %q_rhs, %scales, %zps, %bit_width, %group_size) : (!torch.vtensor<[1,1,8],f16>, !torch.vtensor<[8,4],ui8>, !torch.vtensor<[8,4,1],f16>, !torch.vtensor<[8,4,1],f16>, !torch.int, !torch.int) -> !torch.vtensor<[1,1,8],f16>
+  return %output : !torch.vtensor<[1,1,8],f16>
+}
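
--

Aside on the unpacking scheme (not part of the patch): UnpackQuantizedMatmulWeights
splits each stored i8 into packRatio = dTypeWidth / unpackedBitWidth narrow values,
low bits first. The standalone C++ sketch below mirrors that inner loop for the
first row of the packed literal in unpack-quant-tensor.mlir; the file name and the
variables "packed" and "unpacked" are illustrative only and do not exist in the tree.

    // unpack_sketch.cpp: hypothetical standalone illustration of the
    // i8 -> 2 x i4 unpacking performed by the pass.
    #include <cstdint>
    #include <cstdio>
    #include <vector>

    int main() {
      // First row of the ui8 literal in unpack-quant-tensor.mlir.
      std::vector<uint8_t> packed = {57, 128, 249, 244};
      const int unpackedBitWidth = 4;             // matches %bit_width in the test
      const int packRatio = 8 / unpackedBitWidth; // two i4 values per i8
      std::vector<uint8_t> unpacked;
      for (uint8_t el : packed) {
        uint8_t mask = (1u << unpackedBitWidth) - 1; // 0x0F
        for (int b = 0; b < packRatio; b++) {
          // Low nibble first, then high nibble, as in the pass.
          unpacked.push_back((el & mask) >> (unpackedBitWidth * b));
          mask <<= unpackedBitWidth;
        }
      }
      for (uint8_t v : unpacked)
        std::printf("%d ", v); // prints: 9 3 0 8 9 15 4 15
      std::printf("\n");
      return 0;
    }

The printed row matches the first row of the CHECK literal above (57 = 0x39 yields
9 then 3, 128 = 0x80 yields 0 then 8). Assuming the pass flags behave as the RUN
lines suggest, the two passes would chain in unpack-then-convert order:

    torch-mlir-opt -torch-unpack-quant-tensor -torch-convert-custom-quant-op input.mlir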