Prototype passes for lowering quantized group matmul (#2402)

* Support brevitas custom op (#2320)

* f16 change for brevitas

* Adapt the change of brevitas quant custom op name

* Add unit tests

* Make brevitas conversions isolated

* Address the comments

---------

Co-authored-by: dan <danimal197@gmail.com>
jinchen62 2023-08-29 21:25:45 -07:00 committed by GitHub
parent c42d2beb6e
commit 1682b540bf
8 changed files with 450 additions and 2 deletions


@@ -57,6 +57,14 @@ std::unique_ptr<OperationPass<ModuleOp>> createFuncBackendTypeConversionPass();
std::unique_ptr<OperationPass<func::FuncOp>>
createFinalizingBackendTypeConversionPass();
// These passes do a one-off conversion of a specific kind of quantized group
// matmul as a prototype. Generalized quantized operation handling will likely
// obviate them, but they are being carried for now in order to unblock progress
// on full integrations. See https://github.com/llvm/torch-mlir/issues/2417 for
// the plan to support a more generalized lowering for these graphs.
std::unique_ptr<OperationPass<func::FuncOp>> createUnpackQuantTensorPass();
std::unique_ptr<OperationPass<func::FuncOp>> createConvertCustomQuantOpPass();
std::unique_ptr<OperationPass<ModuleOp>>
createVerifyLinalgOnTensorsBackendContractPass();
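
Both factory functions create function-level passes, so a downstream pipeline would schedule them with addNestedPass. A minimal sketch of the wiring (not part of this diff; it assumes an existing ModuleOp-level PassManager named pm):

// Hypothetical pipeline wiring, not from this patch. For the int4 case,
// UnpackQuantTensor must run first so the i8-packed weight literal is
// widened to ui4 before ConvertCustomQuantOp checks the element width.
pm.addNestedPass<mlir::func::FuncOp>(
    mlir::torch::TorchConversion::createUnpackQuantTensorPass());
pm.addNestedPass<mlir::func::FuncOp>(
    mlir::torch::TorchConversion::createConvertCustomQuantOpPass());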


@@ -48,4 +48,16 @@ def VerifyStablehloBackendContract : Pass<"torch-verify-stablehlo-backend-contract"
let constructor = "mlir::torch::TorchConversion::createVerifyStablehloBackendContractPass()";
}
#endif // TORCH_MLIR_ENABLE_STABLEHLO
// The following passes are for a one-off conversion of a specific kind of quantized group matmul.
// They should not be included in default lowering flows until the lowering is generalized.
def UnpackQuantTensor : Pass<"torch-unpack-quant-tensor", "func::FuncOp"> {
let summary = "Unpack quantized int4 tensor from int8 containter";
let constructor = "mlir::torch::TorchConversion::createUnpackQuantTensorPass()";
}
def ConvertCustomQuantOp : Pass<"torch-convert-custom-quant-op", "func::FuncOp"> {
let summary = "Convert torch custom quant op to linalg";
let constructor = "mlir::torch::TorchConversion::createConvertCustomQuantOpPass()";
}
#endif // TORCHMLIR_TORCHCONVERSION_PASSES


@@ -194,13 +194,13 @@ static bool isValidTorchDtype(Type dtype) {
if (type.isSignless() && type.getWidth() == 1)
return true;
if (type.isSigned()) {
- for (unsigned width : {8, 16, 32, 64}) {
+ for (unsigned width : {4, 8, 16, 32, 64}) {
if (type.getWidth() == width)
return true;
}
}
if (type.isUnsigned()) {
- return type.getWidth() == 8;
+ return type.getWidth() == 8 || type.getWidth() == 4;
}
}
return false;


@@ -25,6 +25,8 @@ add_mlir_library(TorchMLIRTorchConversionPasses
BackendTypeConversion.cpp
BackendTypeConversionPasses.cpp
Passes.cpp
ConvertCustomQuantOp.cpp
UnpackQuantTensor.cpp
VerifyLinalgOnTensorsBackendContract.cpp
VerifyTosaBackendContract.cpp
VerifyStablehloBackendContract.cpp


@@ -0,0 +1,225 @@
//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//
#include "PassDetail.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "torch-mlir/Conversion/Utils/Utils.h"
#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h"
#include "torch-mlir/Dialect/TorchConversion/Transforms/Passes.h"
#include "torch-mlir/Dialect/TorchConversion/Transforms/BackendTypeConversion.h"
using namespace mlir;
using namespace mlir::torch;
using namespace mlir::torch::Torch;
namespace {
class ConvertCustomQuantizedMatmulOp : public OpConversionPattern<OperatorOp> {
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(OperatorOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
if (op.getName().str() != "quant.matmul_rhs_group_quant") {
return failure();
}
Location loc = op->getLoc();
if (failed(verifyLinalgCompatibleTypes(op, rewriter))) {
return failure();
}
// get inputs: lhs, rhsQuant, scales, zps
Value lhs = adaptor.getOperands()[0];
auto lhsType = lhs.getType().dyn_cast<RankedTensorType>();
if (!lhsType) {
return failure();
}
auto lhsShape = lhsType.getShape();
int lhsReductDimSize = lhsShape.back();
Value rhsQuant = adaptor.getOperands()[1];
auto rhsType = rhsQuant.getType().dyn_cast<RankedTensorType>();
if (!rhsType) {
return failure();
}
auto rhsShape = rhsType.getShape();
int rhsReductDimSize = rhsShape.back();
Type rhsElementType = rhsType.getElementType();
Value scales = adaptor.getOperands()[2];
Value zps = adaptor.getOperands()[3];
Value unpackedTypeWidth = adaptor.getOperands()[4];
Value groupSize = adaptor.getOperands()[5];
auto getConstantIntegerFromDefiningOp = [](Value operand,
int &extractedInt) {
auto castOp = dyn_cast<mlir::UnrealizedConversionCastOp>(operand.getDefiningOp());
if (!castOp) {
return failure();
}
auto constOp =
dyn_cast<Torch::ConstantIntOp>(castOp.getOperand(0).getDefiningOp());
if (!constOp) {
return failure();
}
extractedInt = constOp.getValue();
return success();
};
int gs;
if (failed(getConstantIntegerFromDefiningOp(groupSize, gs))) {
return failure();
}
int unpackedBitWidth;
if (failed(getConstantIntegerFromDefiningOp(unpackedTypeWidth, unpackedBitWidth))) {
return failure();
}
if (unpackedBitWidth != rhsElementType.getIntOrFloatBitWidth()) {
return failure();
}
// get outputs
Type newResultType = getTypeConverter()->convertType(op.getType(0));
auto resultType = newResultType.dyn_cast<RankedTensorType>();
if (!resultType) {
return failure();
}
auto resultShape = resultType.getShape();
Type elementType = resultType.getElementType();
// expand lhs
std::vector<int64_t> lhsExpandedShape = {lhsShape[0], lhsShape[1],
lhsReductDimSize / gs, gs};
RankedTensorType lhsExpandedType = RankedTensorType::get(lhsExpandedShape, elementType);
SmallVector<ReassociationIndices, 4> lhsReassociation = {{0}, {1}, {2, 3}};
Value lhsExpanded = rewriter.create<tensor::ExpandShapeOp>(
loc, lhsExpandedType, lhs, lhsReassociation);
// expand rhs
std::vector<int64_t> rhsExpandedShape = {rhsShape[0], rhsReductDimSize / gs, gs};
RankedTensorType rhsExpandedType = RankedTensorType::get(rhsExpandedShape, rhsElementType);
SmallVector<ReassociationIndices, 4> rhsReassociation = {{0}, {1, 2}};
Value rhsExpanded = rewriter.create<tensor::ExpandShapeOp>(
loc, rhsExpandedType, rhsQuant, rhsReassociation);
Value cst0 = rewriter.create<arith::ConstantOp>(
loc, FloatAttr::get(elementType, 0.0));
Value emptyDequant = rewriter.create<tensor::EmptyOp>(
loc, rhsExpandedShape, elementType);
SmallVector<Value> dynDims;
for (int i = 0; i < lhsType.getRank(); i++) {
if (lhsType.isDynamicDim(i)) {
dynDims.push_back(rewriter.create<tensor::DimOp>(loc, lhs, i));
}
}
Value empty = rewriter.create<tensor::EmptyOp>(
loc, resultShape, elementType, dynDims);
Value output = rewriter.create<linalg::FillOp>(
loc, cst0, empty).getResult(0);
AffineExpr d0, d1, d2, d3, d4;
bindDims(getContext(), d0, d1, d2, d3, d4);
auto c0 = rewriter.getAffineConstantExpr(0);
auto map = AffineMap::get(3, 0, {d0, d1, d2}, rewriter.getContext());
auto map1 = AffineMap::get(3, 0, {d0, d1, c0}, rewriter.getContext());
auto map2 = AffineMap::get(5, 0, {d0, d1, d3, d4}, rewriter.getContext());
auto map3 = AffineMap::get(5, 0, {d2, d3, d4}, rewriter.getContext());
auto map4 = AffineMap::get(5, 0, {d0, d1, d2}, rewriter.getContext());
SmallVector<AffineMap, 4> dqIndexingMaps = {map, map1, map1, map};
SmallVector<AffineMap, 4> matIndexingMaps = {map2, map3, map4};
SmallVector<utils::IteratorType> dequantIteratorTypes(3, utils::IteratorType::parallel);
SmallVector<utils::IteratorType> matmulIteratorTypes = {
utils::IteratorType::parallel, utils::IteratorType::parallel,
utils::IteratorType::parallel, utils::IteratorType::reduction,
utils::IteratorType::reduction
};
Value rhsDequant =
rewriter
.create<linalg::GenericOp>(
loc, emptyDequant.getType(),
ValueRange{rhsExpanded, scales, zps}, emptyDequant,
/*indexingMaps=*/dqIndexingMaps,
/*iteratorTypes=*/dequantIteratorTypes,
[&](OpBuilder &b, Location loc, ValueRange args) {
Value w = args[0], scale = args[1], zeroPoint = args[2];
Value extw = b.create<arith::ExtUIOp>(loc, rewriter.getI32Type(), w);
Value fp_extw = b.create<arith::UIToFPOp>(loc, rewriter.getF16Type(), extw);
Value shifted = b.create<arith::SubFOp>(loc, fp_extw, zeroPoint);
Value dqw = b.create<arith::MulFOp>(loc, shifted, scale);
b.create<linalg::YieldOp>(loc, dqw);
})
.getResult(0);
Value matmulDequant =
rewriter
.create<linalg::GenericOp>(
loc, output.getType(),
ValueRange{lhsExpanded, rhsDequant}, output,
/*indexingMaps=*/matIndexingMaps,
/*iteratorTypes=*/matmulIteratorTypes,
[&](OpBuilder &b, Location loc, ValueRange args) {
Value l = args[0], r = args[1], out = args[2];
Value pd = b.create<arith::MulFOp>(loc, l, r);
Value ac = b.create<arith::AddFOp>(loc, pd, out);
b.create<linalg::YieldOp>(loc, ac);
})
.getResult(0);
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, matmulDequant);
return success();
}
};
} // namespace
namespace {
class ConvertCustomQuantOpPass
: public TorchConversion::ConvertCustomQuantOpBase<ConvertCustomQuantOpPass> {
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<arith::ArithDialect>();
registry.insert<func::FuncDialect>();
registry.insert<linalg::LinalgDialect>();
registry.insert<tensor::TensorDialect>();
registry.insert<Torch::TorchDialect>();
TorchConversion::getBackendTypeConversionDependentDialects(registry);
}
void runOnOperation() override {
MLIRContext *context = &getContext();
ConversionTarget target(*context);
target.addLegalDialect<linalg::LinalgDialect, func::FuncDialect,
tensor::TensorDialect, arith::ArithDialect,
Torch::TorchDialect>();
TypeConverter typeConverter;
typeConverter.addConversion([](Type type) { return type; });
TorchConversion::setupBackendTypeConversion(target, typeConverter);
RewritePatternSet patterns(context);
target.addIllegalOp<OperatorOp>();
patterns.add<ConvertCustomQuantizedMatmulOp>(typeConverter, context);
if (failed(
applyPartialConversion(getOperation(), target, std::move(patterns))))
signalPassFailure();
}
};
} // namespace
std::unique_ptr<OperationPass<func::FuncOp>>
mlir::torch::TorchConversion::createConvertCustomQuantOpPass() {
return std::make_unique<ConvertCustomQuantOpPass>();
}
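
The pattern thus emits two linalg.generic ops: a per-group dequantization (extui, uitofp, subf, mulf) followed by an accumulating matmul over the split reduction dimension. As a reference for the intended semantics, here is a standalone scalar sketch of one output element (illustrative only; the helper name and container types are ours, not from the patch):

#include <cstdint>
#include <vector>

// The K-long reduction dim is split into K/groupSize groups; each group of
// a quantized RHS row shares one (scale, zeroPoint) pair, mirroring the
// region bodies of the two generic ops above.
float groupQuantDot(const std::vector<float> &lhsRow,     // K activations
                    const std::vector<uint8_t> &rhsRow,   // K quantized weights
                    const std::vector<float> &scales,     // K/groupSize scales
                    const std::vector<float> &zeroPoints, // K/groupSize zero points
                    int groupSize) {
  float acc = 0.0f;
  for (size_t k = 0; k < lhsRow.size(); ++k) {
    size_t g = k / groupSize; // which quant group this k falls in
    // extui/uitofp, then subf, then mulf
    float w = (static_cast<float>(rhsRow[k]) - zeroPoints[g]) * scales[g];
    acc += lhsRow[k] * w; // mulf + addf accumulation
  }
  return acc;
}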


@@ -0,0 +1,143 @@
//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//
#include "PassDetail.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h"
#include "torch-mlir/Dialect/TorchConversion/Transforms/Passes.h"
using namespace mlir;
using namespace mlir::torch;
using namespace mlir::torch::Torch;
namespace {
class UnpackQuantizedMatmulWeights
: public OpRewritePattern<ValueTensorLiteralOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(ValueTensorLiteralOp constOp,
PatternRewriter &rewriter) const override {
if (!constOp->hasOneUse())
return failure();
OpOperand *use = constOp.getResult().use_begin().getOperand();
auto op = dyn_cast<OperatorOp>(use->getOwner());
if (!op) {
return failure();
}
if (op.getName().str() != "quant.matmul_rhs_group_quant") {
return failure();
}
if (use->getOperandNumber() != 1) {
return failure();
}
Value rhs = op.getOperand(1);
Value bitWidth = op.getOperand(4);
auto getConstantIntegerFromDefiningOp = [](Value operand,
int &extractedInt) {
auto constOp = dyn_cast<Torch::ConstantIntOp>(operand.getDefiningOp());
if (!constOp) {
return failure();
}
extractedInt = constOp.getValue();
return success();
};
int unpackedBitWidth;
if (failed(getConstantIntegerFromDefiningOp(bitWidth, unpackedBitWidth)))
return failure();
auto rhsType = rhs.getType().dyn_cast<ValueTensorType>();
if (!rhsType)
return failure();
if (!rhsType.hasDtype())
return failure();
Type dType = rhsType.getDtype();
int dTypeWidth = dType.getIntOrFloatBitWidth();
if (dTypeWidth == unpackedBitWidth)
return failure();
if (!rhsType.hasSizes())
return failure();
SmallVector<int64_t> tensorShape(rhsType.getSizes());
if (tensorShape.back() == kUnknownSize)
return failure();
int packRatio = dTypeWidth / unpackedBitWidth;
tensorShape[tensorShape.size() - 1] *= packRatio;
Type unpackedElementType;
if (dType.isSignedInteger())
unpackedElementType = rewriter.getIntegerType(unpackedBitWidth, true);
else
unpackedElementType = rewriter.getIntegerType(unpackedBitWidth, false);
ValueTensorType newRhsType = ValueTensorType::get(
rewriter.getContext(), tensorShape, unpackedElementType);
auto elements = constOp.getValueAttr().dyn_cast<DenseIntElementsAttr>();
if (!elements)
return failure();
auto attrType = RankedTensorType::get(tensorShape, unpackedElementType);
// TODO: Materialize IR that does the conversion from quantized type to
// pure integer type, which relies on constant evaluation in backends
auto data = elements.getRawData();
std::vector<APInt> newData(data.size() * packRatio,
APInt(unpackedBitWidth, 0));
for (int i = 0, e = data.size(); i < e; ++i) {
auto el = data[i];
char mask = (1 << unpackedBitWidth) - 1;
for (int b = 0; b < packRatio; b++) {
newData[i * packRatio + b] =
APInt(unpackedBitWidth, (el & mask) >> (unpackedBitWidth * b));
mask = mask << unpackedBitWidth;
}
}
rewriter.replaceOpWithNewOp<ValueTensorLiteralOp>(
constOp, newRhsType,
DenseElementsAttr::get(attrType, ArrayRef<APInt>(newData)));
return success();
}
};
} // namespace
namespace {
class UnpackQuantTensorPass
: public TorchConversion::UnpackQuantTensorBase<UnpackQuantTensorPass> {
using UnpackQuantTensorBase<UnpackQuantTensorPass>::UnpackQuantTensorBase;
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<func::FuncDialect>();
registry.insert<Torch::TorchDialect>();
}
void runOnOperation() override {
MLIRContext *context = &getContext();
RewritePatternSet patterns(context);
patterns.add<UnpackQuantizedMatmulWeights>(context);
if (failed(
applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
signalPassFailure();
}
};
} // namespace
std::unique_ptr<OperationPass<func::FuncOp>>
mlir::torch::TorchConversion::createUnpackQuantTensorPass() {
return std::make_unique<UnpackQuantTensorPass>();
}
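
The nibble loop above expands each i8 container byte into packRatio = 8 / 4 = 2 values, low bits first. A standalone sketch of the same unpacking (illustrative only; the helper name is ours):

#include <cstdint>
#include <vector>

// The low nibble becomes element 2*i and the high nibble element 2*i+1,
// so the packed byte 57 (0x39) unpacks to [9, 3], matching the lit test
// further below.
std::vector<uint8_t> unpackInt4(const std::vector<uint8_t> &packed) {
  std::vector<uint8_t> unpacked;
  unpacked.reserve(packed.size() * 2);
  for (uint8_t byte : packed) {
    unpacked.push_back(byte & 0x0F); // low 4 bits
    unpacked.push_back(byte >> 4);   // high 4 bits
  }
  return unpacked;
}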


@@ -0,0 +1,45 @@
// RUN: torch-mlir-opt %s -torch-convert-custom-quant-op -split-input-file -verify-diagnostics | FileCheck %s
// CHECK: #map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
// CHECK: #map1 = affine_map<(d0, d1, d2) -> (d0, d1, 0)>
// CHECK: #map2 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3, d4)>
// CHECK: #map3 = affine_map<(d0, d1, d2, d3, d4) -> (d2, d3, d4)>
// CHECK: #map4 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>
// CHECK-LABEL: func @forward
func.func @forward(%arg0: !torch.vtensor<[1,1,2],f16>) -> !torch.vtensor<[1,1,2],f16> {
%q_rhs = torch.vtensor.literal(dense<[[0, 1], [2, 3]]> : tensor<2x2xui8>) : !torch.vtensor<[2,2],ui8>
%scales = torch.vtensor.literal(dense<1.0> : tensor<2x1x1xf16>) : !torch.vtensor<[2,1,1],f16>
%zps = torch.vtensor.literal(dense<0.0> : tensor<2x1x1xf16>) : !torch.vtensor<[2,1,1],f16>
%bit_width = torch.constant.int 8
%group_size = torch.constant.int 2
%output = torch.operator "quant.matmul_rhs_group_quant"(%arg0, %q_rhs, %scales, %zps, %bit_width, %group_size) : (!torch.vtensor<[1,1,2],f16>, !torch.vtensor<[2,2],ui8>, !torch.vtensor<[2,1,1],f16>, !torch.vtensor<[2,1,1],f16>, !torch.int, !torch.int) -> !torch.vtensor<[1,1,2],f16>
// CHECK: %[[LHS:.*]] = torch_c.to_builtin_tensor %arg0 : !torch.vtensor<[1,1,2],f16> -> tensor<1x1x2xf16>
// CHECK: %[[TENSOR1:.*]] = torch.vtensor.literal(dense<{{\[\[}}0, 1], [2, 3]]> : tensor<2x2xui8>) : !torch.vtensor<[2,2],ui8>
// CHECK: %[[QUANT_RHS:.*]] = torch_c.to_builtin_tensor %[[TENSOR1]] : !torch.vtensor<[2,2],ui8> -> tensor<2x2xi8>
// CHECK: %[[TENSOR2:.*]] = torch.vtensor.literal(dense<1.000000e+00> : tensor<2x1x1xf16>) : !torch.vtensor<[2,1,1],f16>
// CHECK: %[[SCALES:.*]] = torch_c.to_builtin_tensor %[[TENSOR2]] : !torch.vtensor<[2,1,1],f16> -> tensor<2x1x1xf16>
// CHECK: %[[TENSOR3:.*]] = torch.vtensor.literal(dense<0.000000e+00> : tensor<2x1x1xf16>) : !torch.vtensor<[2,1,1],f16>
// CHECK: %[[ZPS:.*]] = torch_c.to_builtin_tensor %[[TENSOR3]] : !torch.vtensor<[2,1,1],f16> -> tensor<2x1x1xf16>
// CHECK: %[[EXPANDED_LHS:.*]] = tensor.expand_shape %[[LHS]] {{\[\[}}0], [1], [2, 3]] : tensor<1x1x2xf16> into tensor<1x1x1x2xf16>
// CHECK: %[[EXPANDED_RHS:.*]] = tensor.expand_shape %[[QUANT_RHS]] {{\[\[}}0], [1, 2]] : tensor<2x2xi8> into tensor<2x1x2xi8>
// CHECK: %[[CST:.*]] = arith.constant 0.000000e+00 : f16
// CHECK: %[[EMPTY1:.*]] = tensor.empty() : tensor<2x1x2xf16>
// CHECK: %[[EMPTY2:.*]] = tensor.empty() : tensor<1x1x2xf16>
// CHECK: %[[OUT:.*]] = linalg.fill ins(%[[CST]] : f16) outs(%[[EMPTY2]] : tensor<1x1x2xf16>) -> tensor<1x1x2xf16>
// CHECK: %[[DEQUANT_RHS:.*]] = linalg.generic {indexing_maps = [#map, #map1, #map1, #map], iterator_types = ["parallel", "parallel", "parallel"]} ins(%[[EXPANDED_RHS]], %[[SCALES]], %[[ZPS]] : tensor<2x1x2xi8>, tensor<2x1x1xf16>, tensor<2x1x1xf16>) outs(%[[EMPTY1]] : tensor<2x1x2xf16>) {
// CHECK-NEXT: ^bb0(%[[WEIGHTS:.*]]: i8, %[[SCALES:.*]]: f16, %[[ZPS:.*]]: f16, %{{.*}}: f16):
// CHECK-NEXT: %[[EXTUI:.*]] = arith.extui %[[WEIGHTS]] : i8 to i32
// CHECK-NEXT: %[[UITOFP:.*]] = arith.uitofp %[[EXTUI]] : i32 to f16
// CHECK-NEXT: %[[SUBF:.*]] = arith.subf %[[UITOFP]], %[[ZPS]] : f16
// CHECK-NEXT: %[[MULF:.*]] = arith.mulf %[[SUBF]], %[[SCALES]] : f16
// CHECK-NEXT: linalg.yield %[[MULF]] : f16
// CHECK-NEXT: } -> tensor<2x1x2xf16>
// CHECK: %[[MATMUL:.*]] = linalg.generic {indexing_maps = [#map2, #map3, #map4], iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction"]} ins(%[[EXPANDED_LHS]], %[[DEQUANT_RHS]] : tensor<1x1x1x2xf16>, tensor<2x1x2xf16>) outs(%[[OUT]] : tensor<1x1x2xf16>) {
// CHECK-NEXT: ^bb0(%[[LHS:.*]]: f16, %[[RHS:.*]]: f16, %[[OUT:.*]]: f16):
// CHECK-NEXT: %[[MULF:.*]] = arith.mulf %[[LHS]], %[[RHS]] : f16
// CHECK-NEXT: %[[ADDF:.*]] = arith.addf %[[MULF]], %[[OUT]] : f16
// CHECK-NEXT: linalg.yield %[[ADDF]] : f16
// CHECK-NEXT: } -> tensor<1x1x2xf16>
// CHECK: %[[CASTED:.*]] = tensor.cast %[[MATMUL]] : tensor<1x1x2xf16> to tensor<1x1x2xf16>
return %output : !torch.vtensor<[1,1,2],f16>
}


@@ -0,0 +1,13 @@
// RUN: torch-mlir-opt %s -torch-unpack-quant-tensor -split-input-file -verify-diagnostics | FileCheck %s
// CHECK-LABEL: func @forward
func.func @forward(%arg0: !torch.vtensor<[1,1,8],f16>) -> !torch.vtensor<[1,1,8],f16> {
%q_rhs = torch.vtensor.literal(dense<[[57, 128, 249, 244], [7, 243, 27, 15], [1, 2, 159, 71], [159, 253, 160, 231], [248, 224, 191, 228], [96, 15, 158, 220], [240, 250, 47, 208], [127, 192, 239, 176]]> : tensor<8x4xui8>) : !torch.vtensor<[8,4],ui8>
// CHECK: %[[C0:.*]] = torch.vtensor.literal(dense<{{\[\[}}9, 3, 0, 8, 9, 15, 4, 15], [7, 0, 3, 15, 11, 1, 15, 0], [1, 0, 2, 0, 15, 9, 7, 4], [15, 9, 13, 15, 0, 10, 7, 14], [8, 15, 0, 14, 15, 11, 4, 14], [0, 6, 15, 0, 14, 9, 12, 13], [0, 15, 10, 15, 15, 2, 0, 13], [15, 7, 0, 12, 15, 14, 0, 11]]> : tensor<8x8xui4>) : !torch.vtensor<[8,8],ui4>
%scales = torch.vtensor.literal(dense<1.0> : tensor<8x4x1xf16>) : !torch.vtensor<[8,4,1],f16>
%zps = torch.vtensor.literal(dense<0.0> : tensor<8x4x1xf16>) : !torch.vtensor<[8,4,1],f16>
%bit_width = torch.constant.int 4
%group_size = torch.constant.int 2
%output = torch.operator "quant.matmul_rhs_group_quant"(%arg0, %q_rhs, %scales, %zps, %bit_width, %group_size) : (!torch.vtensor<[1,1,8],f16>, !torch.vtensor<[8,4],ui8>, !torch.vtensor<[8,4,1],f16>, !torch.vtensor<[8,4,1],f16>, !torch.int, !torch.int) -> !torch.vtensor<[1,1,8],f16>
return %output : !torch.vtensor<[1,1,8],f16>
}