mirror of https://github.com/llvm/torch-mlir

Address the comments

commit 5a48bf1cad (parent b8f361ad35)
@@ -57,6 +57,11 @@ std::unique_ptr<OperationPass<ModuleOp>> createFuncBackendTypeConversionPass();
 std::unique_ptr<OperationPass<func::FuncOp>>
 createFinalizingBackendTypeConversionPass();
 
+// These passes do a one-off conversion of a specific kind of quantized group
+// matmul as a prototype. Generalized quantized operation handling will likely
+// obviate them, but they are being carried for now in order to unblock progress
+// on full integrations. See https://github.com/llvm/torch-mlir/issues/2417 for
+// the plan to support a more generalized lowering for these graphs.
 std::unique_ptr<OperationPass<func::FuncOp>> createUnpackQuantTensorPass();
 std::unique_ptr<OperationPass<func::FuncOp>> createConvertCustomQuantOpPass();
 
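Since these prototype passes are kept out of the default lowering flows, a caller schedules them explicitly. A minimal sketch of wiring them into a PassManager, assuming a caller-provided pipeline (the helper name and surrounding setup are illustrative, not part of this change):

#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Pass/PassManager.h"
#include "torch-mlir/Dialect/TorchConversion/Transforms/Passes.h"

// Illustrative helper: schedule the two prototype passes on every function.
// Both are func::FuncOp passes, so they run as nested passes.
void addQuantPrototypePasses(mlir::PassManager &pm) {
  // Unpack int4 weights from their int8 container first...
  pm.addNestedPass<mlir::func::FuncOp>(
      mlir::torch::TorchConversion::createUnpackQuantTensorPass());
  // ...then lower the custom quantized matmul op to linalg.
  pm.addNestedPass<mlir::func::FuncOp>(
      mlir::torch::TorchConversion::createConvertCustomQuantOpPass());
}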
@@ -32,16 +32,6 @@ def FinalizingBackendTypeConversion
   }];
 }
 
-def UnpackQuantTensor : Pass<"torch-unpack-quant-tensor", "func::FuncOp"> {
-  let summary = "Unpack quantized int4 tensor from int8 container";
-  let constructor = "mlir::torch::TorchConversion::createUnpackQuantTensorPass()";
-}
-
-def ConvertCustomQuantOp : Pass<"torch-convert-custom-quant-op", "func::FuncOp"> {
-  let summary = "Convert torch custom quant op to linalg";
-  let constructor = "mlir::torch::TorchConversion::createConvertCustomQuantOpPass()";
-}
-
 def VerifyLinalgOnTensorsBackendContract : Pass<"torch-verify-linalg-on-tensors-backend-contract", "ModuleOp"> {
   let summary = "Verifies conformity to the linalg-on-tensors backend contract";
   let constructor = "mlir::torch::TorchConversion::createVerifyLinalgOnTensorsBackendContractPass()";
@@ -58,4 +48,16 @@ def VerifyStablehloBackendContract : Pass<"torch-verify-stablehlo-backend-contra
   let constructor = "mlir::torch::TorchConversion::createVerifyStablehloBackendContractPass()";
 }
 #endif // TORCH_MLIR_ENABLE_STABLEHLO
+
+// The following passes are for a one-off conversion of a specific kind of quantized group matmul.
+// They should not be included in default lowering flows until further along.
+def UnpackQuantTensor : Pass<"torch-unpack-quant-tensor", "func::FuncOp"> {
+  let summary = "Unpack quantized int4 tensor from int8 container";
+  let constructor = "mlir::torch::TorchConversion::createUnpackQuantTensorPass()";
+}
+
+def ConvertCustomQuantOp : Pass<"torch-convert-custom-quant-op", "func::FuncOp"> {
+  let summary = "Convert torch custom quant op to linalg";
+  let constructor = "mlir::torch::TorchConversion::createConvertCustomQuantOpPass()";
+}
 #endif // TORCHMLIR_TORCHCONVERSION_PASSES
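The pass names registered here are exactly the flags that the updated RUN lines in the tests below hand to torch-mlir-opt; a standalone invocation (input file name illustrative) looks like:

    torch-mlir-opt quant.mlir -torch-unpack-quant-tensor -torch-convert-custom-quant-op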
@@ -40,28 +40,28 @@ public:
       return failure();
     }
 
-    // get inputs: lhs, q_rhs, scales, zps
+    // get inputs: lhs, rhsQuant, scales, zps
     Value lhs = adaptor.getOperands()[0];
     auto lhsType = lhs.getType().cast<RankedTensorType>();
     if (!lhsType) {
       return failure();
     }
     auto lhsShape = lhsType.getShape();
-    int lhs_reduct_dim_size = lhsShape.back();
+    int lhsReductDimSize = lhsShape.back();
 
-    Value q_rhs = adaptor.getOperands()[1];
-    auto rhsType = q_rhs.getType().cast<RankedTensorType>();
+    Value rhsQuant = adaptor.getOperands()[1];
+    auto rhsType = rhsQuant.getType().cast<RankedTensorType>();
     if (!rhsType) {
       return failure();
     }
     auto rhsShape = rhsType.getShape();
-    int rhs_reduct_dim_size = rhsShape.back();
-    Type rhs_elementType = rhsType.getElementType();
+    int rhsReductDimSize = rhsShape.back();
+    Type rhsElementType = rhsType.getElementType();
 
     Value scales = adaptor.getOperands()[2];
     Value zps = adaptor.getOperands()[3];
-    Value unpacked_type_width = adaptor.getOperands()[4];
-    Value group_size = adaptor.getOperands()[5];
+    Value unpackedTypeWidth = adaptor.getOperands()[4];
+    Value groupSize = adaptor.getOperands()[5];
 
     auto getConstantIntegerFromDefiningOp = [](Value operand,
                                                int &extractedInt) {
@@ -79,14 +79,14 @@ public:
     };
 
     int gs;
-    if (failed(getConstantIntegerFromDefiningOp(group_size, gs))) {
+    if (failed(getConstantIntegerFromDefiningOp(groupSize, gs))) {
       return failure();
     }
    int unpackedBitWidth;
-    if (failed(getConstantIntegerFromDefiningOp(unpacked_type_width, unpackedBitWidth))) {
+    if (failed(getConstantIntegerFromDefiningOp(unpackedTypeWidth, unpackedBitWidth))) {
       return failure();
     }
-    if (unpackedBitWidth != rhs_elementType.getIntOrFloatBitWidth()) {
+    if (unpackedBitWidth != rhsElementType.getIntOrFloatBitWidth()) {
       return failure();
     }
 
@@ -100,24 +100,24 @@ public:
     Type elementType = resultType.getElementType();
 
     // expand lhs
-    std::vector<int64_t> lhs_expandedShape = {lhsShape[0], lhsShape[1],
-                                              lhs_reduct_dim_size / gs, gs};
-    RankedTensorType lhs_expandedType = RankedTensorType::get(lhs_expandedShape, elementType);
-    SmallVector<ReassociationIndices, 4> lhs_reassociation = {{0}, {1}, {2, 3}};
-    Value expanded_lhs = rewriter.create<tensor::ExpandShapeOp>(
-        loc, lhs_expandedType, lhs, lhs_reassociation);
+    std::vector<int64_t> lhsExpandedShape = {lhsShape[0], lhsShape[1],
+                                             lhsReductDimSize / gs, gs};
+    RankedTensorType lhsExpandedType = RankedTensorType::get(lhsExpandedShape, elementType);
+    SmallVector<ReassociationIndices, 4> lhsReassociation = {{0}, {1}, {2, 3}};
+    Value lhsExpanded = rewriter.create<tensor::ExpandShapeOp>(
+        loc, lhsExpandedType, lhs, lhsReassociation);
 
     // expand rhs
-    std::vector<int64_t> expandedShape = {rhsShape[0], rhs_reduct_dim_size / gs, gs};
-    RankedTensorType expandedType = RankedTensorType::get(expandedShape, rhs_elementType);
-    SmallVector<ReassociationIndices, 4> reassociation = {{0}, {1, 2}};
-    Value expanded_rhs = rewriter.create<tensor::ExpandShapeOp>(
-        loc, expandedType, q_rhs, reassociation);
-    Value cst_0 = rewriter.create<arith::ConstantOp>(
+    std::vector<int64_t> rhsExpandedShape = {rhsShape[0], rhsReductDimSize / gs, gs};
+    RankedTensorType rhsExpandedType = RankedTensorType::get(rhsExpandedShape, rhsElementType);
+    SmallVector<ReassociationIndices, 4> rhsReassociation = {{0}, {1, 2}};
+    Value rhsExpanded = rewriter.create<tensor::ExpandShapeOp>(
+        loc, rhsExpandedType, rhsQuant, rhsReassociation);
+    Value cst0 = rewriter.create<arith::ConstantOp>(
         loc, FloatAttr::get(elementType, 0.0));
 
-    Value dq_empty = rewriter.create<tensor::EmptyOp>(
-        loc, expandedShape, elementType);
+    Value emptyDequant = rewriter.create<tensor::EmptyOp>(
+        loc, rhsExpandedShape, elementType);
     SmallVector<Value> dynDims;
     for (int i = 0; i < lhsType.getRank(); i++) {
       if (lhsType.isDynamicDim(i)) {
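The two tensor.expand_shape ops split the shared reduction dimension K into K/gs groups of size gs, so that one per-group scale and zero point lines up with each slice of the weights. Schematically (B, M, N, K are shorthand for lhsShape[0], lhsShape[1], rhsShape[0], and the reduction size; the code itself does not name them):

    lhs: (B, M, K) -> (B, M, K/gs, gs)   reassociation {{0}, {1}, {2, 3}}
    rhs: (N, K)    -> (N, K/gs, gs)      reassociation {{0}, {1, 2}}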
@@ -127,7 +127,7 @@ public:
     Value empty = rewriter.create<tensor::EmptyOp>(
         loc, resultShape, elementType, dynDims);
     Value output = rewriter.create<linalg::FillOp>(
-        loc, cst_0, empty).getResult(0);
+        loc, cst0, empty).getResult(0);
 
     AffineExpr d0, d1, d2, d3, d4;
     bindDims(getContext(), d0, d1, d2, d3, d4);
@@ -137,23 +137,23 @@ public:
     auto map2 = AffineMap::get(5, 0, {d0, d1, d3, d4}, rewriter.getContext());
     auto map3 = AffineMap::get(5, 0, {d2, d3, d4}, rewriter.getContext());
     auto map4 = AffineMap::get(5, 0, {d0, d1, d2}, rewriter.getContext());
-    SmallVector<AffineMap, 4> dq_indexingMaps = {map, map1, map1, map};
-    SmallVector<AffineMap, 4> mat_indexingMaps = {map2, map3, map4};
+    SmallVector<AffineMap, 4> dqIndexingMaps = {map, map1, map1, map};
+    SmallVector<AffineMap, 4> matIndexingMaps = {map2, map3, map4};
 
-    SmallVector<utils::IteratorType> dq_iteratorTypes(3, utils::IteratorType::parallel);
-    SmallVector<utils::IteratorType> mat_iteratorTypes = {
+    SmallVector<utils::IteratorType> dequantIteratorTypes(3, utils::IteratorType::parallel);
+    SmallVector<utils::IteratorType> matmulIteratorTypes = {
       utils::IteratorType::parallel, utils::IteratorType::parallel,
       utils::IteratorType::parallel, utils::IteratorType::reduction,
       utils::IteratorType::reduction
     };
 
-    Value dq_rhs =
+    Value rhsDequant =
         rewriter
             .create<linalg::GenericOp>(
-                loc, dq_empty.getType(),
-                ValueRange{expanded_rhs, scales, zps}, dq_empty,
-                /*indexingMaps=*/dq_indexingMaps,
-                /*iteratorTypes=*/dq_iteratorTypes,
+                loc, emptyDequant.getType(),
+                ValueRange{rhsExpanded, scales, zps}, emptyDequant,
+                /*indexingMaps=*/dqIndexingMaps,
+                /*iteratorTypes=*/dequantIteratorTypes,
                 [&](OpBuilder &b, Location loc, ValueRange args) {
                   Value w = args[0], scale = args[1], zeroPoint = args[2];
                   Value extw = b.create<arith::ExtUIOp>(loc, rewriter.getI32Type(), w);
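The region body built here is a straightforward affine dequantization. As a scalar model of what the arith op chain computes (a sketch for illustration only; the pass emits IR on the tensor element type, not this function):

#include <cstdint>

// value = (float(w) - zeroPoint) * scale, with w zero-extended first,
// mirroring the arith.extui -> arith.uitofp -> arith.subf -> arith.mulf chain.
float dequantize(uint8_t w, float scale, float zeroPoint) {
  uint32_t extw = static_cast<uint32_t>(w); // arith.extui to i32
  float wFp = static_cast<float>(extw);     // arith.uitofp
  return (wFp - zeroPoint) * scale;         // arith.subf, arith.mulf
}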
@@ -164,13 +164,13 @@ public:
                 })
             .getResult(0);
 
-    Value quantMat =
+    Value matmulDequant =
        rewriter
            .create<linalg::GenericOp>(
                loc, output.getType(),
-                ValueRange{expanded_lhs, dq_rhs}, output,
-                /*indexingMaps=*/mat_indexingMaps,
-                /*iteratorTypes=*/mat_iteratorTypes,
+                ValueRange{lhsExpanded, rhsDequant}, output,
+                /*indexingMaps=*/matIndexingMaps,
+                /*iteratorTypes=*/matmulIteratorTypes,
                 [&](OpBuilder &b, Location loc, ValueRange args) {
                   Value l = args[0], r = args[1], out = args[2];
                   Value pd = b.create<arith::MulFOp>(loc, l, r);
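With d0 and d1 indexing the lhs batch and row, d2 the rhs row, and d3/d4 the group index and position within the group, map2/map3/map4 plus the two reduction iterators give this generic the semantics of a grouped matmul, with d3 and d4 together sweeping the original reduction dimension:

    output[d0, d1, d2] += lhsExpanded[d0, d1, d3, d4] * rhsDequant[d2, d3, d4]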
@@ -179,7 +179,7 @@ public:
                 })
             .getResult(0);
 
-    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, quantMat);
+    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, matmulDequant);
     return success();
   }
 };
@@ -32,16 +32,17 @@ public:
 
     OpOperand *use = constOp.getResult().use_begin().getOperand();
     auto op = dyn_cast<OperatorOp>(use->getOwner());
-    if (!op)
+    if (!op) {
       return failure();
-
-    if (use->getOperandNumber() != 1)
-      return failure();
-
+    }
     if (op.getName().str() != "quant.matmul_rhs_group_quant") {
       return failure();
     }
 
+    if (use->getOperandNumber() != 1) {
+      return failure();
+    }
+
     Value rhs = op.getOperand(1);
     Value bitWidth = op.getOperand(4);
 
@@ -93,7 +94,8 @@ public:
 
     auto attrType = RankedTensorType::get(tensorShape, unpackedElementType);
 
-    // This is terrible but idk what else to do.
+    // TODO: Materialize IR that does the conversion from quantized type to
+    // pure integer type, which relies on constant evaluation in backends.
     auto data = elements.getRawData();
     std::vector<APInt> newData(data.size() * packRatio,
                                APInt(unpackedBitWidth, 0));
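The rewrite that follows expands each i8 container byte into packRatio narrow values. For the int4 case (packRatio == 2), its effect on the raw bytes is equivalent to this sketch (illustrative; the pass materializes APInt values rather than a byte vector):

#include <cstdint>
#include <vector>

// Split each i8 container byte into two unsigned 4-bit values, low nibble
// first -- e.g. 57 (0b0011'1001) unpacks to {9, 3}, matching the expected
// literal in the unpack-quant-tensor test below.
std::vector<uint8_t> unpackInt4(const std::vector<uint8_t> &data) {
  std::vector<uint8_t> out;
  out.reserve(data.size() * 2); // packRatio == 8 / 4 == 2
  for (uint8_t byte : data) {
    out.push_back(byte & 0x0F);        // low nibble
    out.push_back((byte >> 4) & 0x0F); // high nibble
  }
  return out;
}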
@@ -1,12 +1,45 @@
-// RUN: torch-mlir-opt %s -torch-convert-custom-quant-op -verify-diagnostics | FileCheck %s
+// RUN: torch-mlir-opt %s -torch-convert-custom-quant-op -split-input-file -verify-diagnostics | FileCheck %s
 
+// CHECK: #map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// CHECK: #map1 = affine_map<(d0, d1, d2) -> (d0, d1, 0)>
+// CHECK: #map2 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3, d4)>
+// CHECK: #map3 = affine_map<(d0, d1, d2, d3, d4) -> (d2, d3, d4)>
+// CHECK: #map4 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>
 // CHECK-LABEL: func @forward
 func.func @forward(%arg0: !torch.vtensor<[1,1,2],f16>) -> !torch.vtensor<[1,1,2],f16> {
   %q_rhs = torch.vtensor.literal(dense<[[0, 1], [2, 3]]> : tensor<2x2xui8>) : !torch.vtensor<[2,2],ui8>
-  %scales = torch.vtensor.literal(dense<[[[1.0]], [[1.0]]]> : tensor<2x1x1xf16>) : !torch.vtensor<[2,1,1],f16>
-  %zps = torch.vtensor.literal(dense<[[[0.0]], [[0.0]]]> : tensor<2x1x1xf16>) : !torch.vtensor<[2,1,1],f16>
+  %scales = torch.vtensor.literal(dense<1.0> : tensor<2x1x1xf16>) : !torch.vtensor<[2,1,1],f16>
+  %zps = torch.vtensor.literal(dense<0.0> : tensor<2x1x1xf16>) : !torch.vtensor<[2,1,1],f16>
   %bit_width = torch.constant.int 8
   %group_size = torch.constant.int 2
   %output = torch.operator "quant.matmul_rhs_group_quant"(%arg0, %q_rhs, %scales, %zps, %bit_width, %group_size) : (!torch.vtensor<[1,1,2],f16>, !torch.vtensor<[2,2],ui8>, !torch.vtensor<[2,1,1],f16>, !torch.vtensor<[2,1,1],f16>, !torch.int, !torch.int) -> !torch.vtensor<[1,1,2],f16>
+  // CHECK: %[[LHS:.*]] = torch_c.to_builtin_tensor %arg0 : !torch.vtensor<[1,1,2],f16> -> tensor<1x1x2xf16>
+  // CHECK: %[[TENSOR1:.*]] = torch.vtensor.literal(dense<{{\[\[}}0, 1], [2, 3]]> : tensor<2x2xui8>) : !torch.vtensor<[2,2],ui8>
+  // CHECK: %[[QUANT_RHS:.*]] = torch_c.to_builtin_tensor %[[TENSOR1]] : !torch.vtensor<[2,2],ui8> -> tensor<2x2xi8>
+  // CHECK: %[[TENSOR2:.*]] = torch.vtensor.literal(dense<1.000000e+00> : tensor<2x1x1xf16>) : !torch.vtensor<[2,1,1],f16>
+  // CHECK: %[[SCALES:.*]] = torch_c.to_builtin_tensor %[[TENSOR2]] : !torch.vtensor<[2,1,1],f16> -> tensor<2x1x1xf16>
+  // CHECK: %[[TENSOR3:.*]] = torch.vtensor.literal(dense<0.000000e+00> : tensor<2x1x1xf16>) : !torch.vtensor<[2,1,1],f16>
+  // CHECK: %[[ZPS:.*]] = torch_c.to_builtin_tensor %[[TENSOR3]] : !torch.vtensor<[2,1,1],f16> -> tensor<2x1x1xf16>
+  // CHECK: %[[EXPANDED_LHS:.*]] = tensor.expand_shape %[[LHS]] {{\[\[}}0], [1], [2, 3]] : tensor<1x1x2xf16> into tensor<1x1x1x2xf16>
+  // CHECK: %[[EXPANDED_RHS:.*]] = tensor.expand_shape %[[QUANT_RHS]] {{\[\[}}0], [1, 2]] : tensor<2x2xi8> into tensor<2x1x2xi8>
+  // CHECK: %[[CST:.*]] = arith.constant 0.000000e+00 : f16
+  // CHECK: %[[EMPTY1:.*]] = tensor.empty() : tensor<2x1x2xf16>
+  // CHECK: %[[EMPTY2:.*]] = tensor.empty() : tensor<1x1x2xf16>
+  // CHECK: %[[OUT:.*]] = linalg.fill ins(%[[CST]] : f16) outs(%[[EMPTY2]] : tensor<1x1x2xf16>) -> tensor<1x1x2xf16>
+  // CHECK: %[[DEQUANT_RHS:.*]] = linalg.generic {indexing_maps = [#map, #map1, #map1, #map], iterator_types = ["parallel", "parallel", "parallel"]} ins(%[[EXPANDED_RHS]], %[[SCALES]], %[[ZPS]] : tensor<2x1x2xi8>, tensor<2x1x1xf16>, tensor<2x1x1xf16>) outs(%[[EMPTY1]] : tensor<2x1x2xf16>) {
+  // CHECK-NEXT: ^bb0(%[[WEIGHTS:.*]]: i8, %[[SCALES:.*]]: f16, %[[ZPS:.*]]: f16, %{{.*}}: f16):
+  // CHECK-NEXT: %[[EXTUI:.*]] = arith.extui %[[WEIGHTS]] : i8 to i32
+  // CHECK-NEXT: %[[UITOFP:.*]] = arith.uitofp %[[EXTUI]] : i32 to f16
+  // CHECK-NEXT: %[[SUBF:.*]] = arith.subf %[[UITOFP]], %[[ZPS]] : f16
+  // CHECK-NEXT: %[[MULF:.*]] = arith.mulf %[[SUBF]], %[[SCALES]] : f16
+  // CHECK-NEXT: linalg.yield %[[MULF]] : f16
+  // CHECK-NEXT: } -> tensor<2x1x2xf16>
+  // CHECK: %[[MATMUL:.*]] = linalg.generic {indexing_maps = [#map2, #map3, #map4], iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction"]} ins(%[[EXPANDED_LHS]], %[[DEQUANT_RHS]] : tensor<1x1x1x2xf16>, tensor<2x1x2xf16>) outs(%[[OUT]] : tensor<1x1x2xf16>) {
+  // CHECK-NEXT: ^bb0(%[[LHS:.*]]: f16, %[[RHS:.*]]: f16, %[[OUT:.*]]: f16):
+  // CHECK-NEXT: %[[MULF:.*]] = arith.mulf %[[LHS]], %[[RHS]] : f16
+  // CHECK-NEXT: %[[ADDF:.*]] = arith.addf %[[MULF]], %[[OUT]] : f16
+  // CHECK-NEXT: linalg.yield %[[ADDF]] : f16
+  // CHECK-NEXT: } -> tensor<1x1x2xf16>
+  // CHECK: %[[CASTED:.*]] = tensor.cast %[[MATMUL]] : tensor<1x1x2xf16> to tensor<1x1x2xf16>
   return %output : !torch.vtensor<[1,1,2],f16>
 }
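A note on reading this test: the scales are all 1.0 and the zero points all 0.0, so the dequantization generic is the identity on the weight values, and the checked IR reduces to a plain f16 matmul against [[0, 1], [2, 3]]:

    %output[0, 0, n] = sum_k %arg0[0, 0, k] * %q_rhs[n, k]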
@@ -1,10 +1,11 @@
-// RUN: torch-mlir-opt %s -torch-unpack-quant-tensor | FileCheck %s
+// RUN: torch-mlir-opt %s -torch-unpack-quant-tensor -split-input-file -verify-diagnostics | FileCheck %s
 
 // CHECK-LABEL: func @forward
 func.func @forward(%arg0: !torch.vtensor<[1,1,8],f16>) -> !torch.vtensor<[1,1,8],f16> {
   %q_rhs = torch.vtensor.literal(dense<[[57, 128, 249, 244], [7, 243, 27, 15], [1, 2, 159, 71], [159, 253, 160, 231], [248, 224, 191, 228], [96, 15, 158, 220], [240, 250, 47, 208], [127, 192, 239, 176]]> : tensor<8x4xui8>) : !torch.vtensor<[8,4],ui8>
-  %scales = torch.vtensor.literal(dense<[[[1.0], [1.0], [1.0], [1.0]], [[1.0], [1.0], [1.0], [1.0]], [[1.0], [1.0], [1.0], [1.0]], [[1.0], [1.0], [1.0], [1.0]], [[1.0], [1.0], [1.0], [1.0]], [[1.0], [1.0], [1.0], [1.0]], [[1.0], [1.0], [1.0], [1.0]], [[1.0], [1.0], [1.0], [1.0]]]> : tensor<8x4x1xf16>) : !torch.vtensor<[8,4,1],f16>
-  %zps = torch.vtensor.literal(dense<[[[0.0], [0.0], [0.0], [0.0]], [[0.0], [0.0], [0.0], [0.0]], [[0.0], [0.0], [0.0], [0.0]], [[0.0], [0.0], [0.0], [0.0]], [[0.0], [0.0], [0.0], [0.0]], [[0.0], [0.0], [0.0], [0.0]], [[0.0], [0.0], [0.0], [0.0]], [[0.0], [0.0], [0.0], [0.0]]]> : tensor<8x4x1xf16>) : !torch.vtensor<[8,4,1],f16>
+  // CHECK: %[[C0:.*]] = torch.vtensor.literal(dense<{{\[\[}}9, 3, 0, 8, 9, 15, 4, 15], [7, 0, 3, 15, 11, 1, 15, 0], [1, 0, 2, 0, 15, 9, 7, 4], [15, 9, 13, 15, 0, 10, 7, 14], [8, 15, 0, 14, 15, 11, 4, 14], [0, 6, 15, 0, 14, 9, 12, 13], [0, 15, 10, 15, 15, 2, 0, 13], [15, 7, 0, 12, 15, 14, 0, 11]]> : tensor<8x8xui4>) : !torch.vtensor<[8,8],ui4>
+  %scales = torch.vtensor.literal(dense<1.0> : tensor<8x4x1xf16>) : !torch.vtensor<[8,4,1],f16>
+  %zps = torch.vtensor.literal(dense<0.0> : tensor<8x4x1xf16>) : !torch.vtensor<[8,4,1],f16>
   %bit_width = torch.constant.int 4
   %group_size = torch.constant.int 2
   %output = torch.operator "quant.matmul_rhs_group_quant"(%arg0, %q_rhs, %scales, %zps, %bit_width, %group_size) : (!torch.vtensor<[1,1,8],f16>, !torch.vtensor<[8,4],ui8>, !torch.vtensor<[8,4,1],f16>, !torch.vtensor<[8,4,1],f16>, !torch.int, !torch.int) -> !torch.vtensor<[1,1,8],f16>