mirror of https://github.com/llvm/torch-mlir
[torch] `torch.dequantize` for per channel tensors to `linalg` (#2769)
Support a lowering of dequantization for per-channel tensors from the `torch` dialect to a `linalg` decomposition. Tested via a numerical `torch` test.
parent 0aed231e21
commit 2ef228328f
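
For reference, the decomposition implements the standard per-channel dequantization formula: along `axis`, each channel `c` is recovered as `(int_value - zero_point[c]) * scale[c]`. A minimal eager-PyTorch sketch of the semantics being lowered (illustrative only; it mirrors the shapes and dtypes used by the new e2e test in this patch):

import torch

# Per-channel dequantization semantics (illustrative sketch, not part of the patch):
# along `axis`, channel c is dequantized as (int_value - zero_point[c]) * scale[c].
x = torch.randint(-128, 127, (3, 4), dtype=torch.int8)
scale = torch.rand(4)
zeropoint = torch.randint(-128, 127, (4,), dtype=torch.int8)

qx = torch._make_per_channel_quantized_tensor(x, scale, zeropoint, axis=1)
expected = (x.to(torch.int32) - zeropoint.to(torch.int32)) * scale
assert torch.allclose(torch.dequantize(qx), expected.to(torch.float))
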
@@ -14465,6 +14465,33 @@ def Torch_AtenLeakyReluBackwardOp : Torch_Op<"aten.leaky_relu_backward", [
   }];
 }
 
+def Torch_AtenQuantizePerChannelOp : Torch_Op<"aten.quantize_per_channel", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::quantize_per_channel : (Tensor, Tensor, Tensor, int, int) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$self,
+    AnyTorchTensorType:$scales,
+    AnyTorchTensorType:$zero_points,
+    Torch_IntType:$axis,
+    Torch_IntType:$dtype
+  );
+  let results = (outs
+    AnyTorchTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenQuantizePerChannelOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 5, 1);
+    }
+    void AtenQuantizePerChannelOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 5, 1);
+    }
+  }];
+}
+
 def Torch_AtenQuantizePerTensorOp : Torch_Op<"aten.quantize_per_tensor", [
     AllowsTypeRefinement,
     HasValueSemantics,

@@ -14560,6 +14587,32 @@ def Torch_AtenIntReprOp : Torch_Op<"aten.int_repr", [
   }];
 }
 
+def Torch_Aten_MakePerChannelQuantizedTensorOp : Torch_Op<"aten._make_per_channel_quantized_tensor", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::_make_per_channel_quantized_tensor : (Tensor, Tensor, Tensor, int) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$self,
+    AnyTorchTensorType:$scale,
+    AnyTorchTensorType:$zero_point,
+    Torch_IntType:$axis
+  );
+  let results = (outs
+    AnyTorchTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult Aten_MakePerChannelQuantizedTensorOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 4, 1);
+    }
+    void Aten_MakePerChannelQuantizedTensorOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 4, 1);
+    }
+  }];
+}
+
 def Torch_Aten_MakePerTensorQuantizedTensorOp : Torch_Op<"aten._make_per_tensor_quantized_tensor", [
     AllowsTypeRefinement,
     HasValueSemantics,

@@ -1344,7 +1344,7 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
     auto makeQTensor =
         qtensor.getDefiningOp<Aten_MakePerTensorQuantizedTensorOp>();
     if (!makeQTensor) {
-      op->emitError(
+      op->emitWarning(
           "unimplemented: dequantizing tensor of unknown scale / zero-point");
      return nullptr;
    }

@@ -2221,16 +2221,109 @@ public:
 } // namespace
 
 namespace {
-class ConvertMakePerTensorQuantizedTensorOp
-    : public OpConversionPattern<Aten_MakePerTensorQuantizedTensorOp> {
+class ConvertDequantizePerChannel
+    : public OpConversionPattern<AtenDequantizeSelfOp> {
 public:
   using OpConversionPattern::OpConversionPattern;
   LogicalResult
-  matchAndRewrite(Aten_MakePerTensorQuantizedTensorOp op, OpAdaptor adaptor,
+  matchAndRewrite(AtenDequantizeSelfOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    RankedTensorType resultType = getTypeConverter()
-                                      ->convertType(op->getResult(0).getType())
-                                      .cast<RankedTensorType>();
+    auto loc = op.getLoc();
+    auto qoperand = op.getOperand();
+    auto make = qoperand.getDefiningOp<Aten_MakePerChannelQuantizedTensorOp>();
+    if (!make) {
+      llvm::errs() << "Did not find make per channel\n";
+      return rewriter.notifyMatchFailure(op, "did not find per channel qint");
+    }
+
+    auto converter = getTypeConverter();
+    auto operand = make.getOperand(0);
+    auto scale = make.getScale();
+    auto zeropoint = make.getZeroPoint();
+    auto axis = make.getAxis();
+
+    IntegerAttr axisAttr;
+    if (!matchPattern(axis, m_Constant(&axisAttr))) {
+      return failure();
+    }
+
+    auto operandDTy = operand.getType().cast<ValueTensorType>().getDtype();
+    auto zeropointDTy = zeropoint.getType().cast<ValueTensorType>().getDtype();
+    operand = converter->materializeTargetConversion(
+        rewriter, loc, converter->convertType(operand.getType()), operand);
+    scale = converter->materializeTargetConversion(
+        rewriter, loc, converter->convertType(scale.getType()), scale);
+    zeropoint = converter->materializeTargetConversion(
+        rewriter, loc, converter->convertType(zeropoint.getType()), zeropoint);
+
+    auto resultType = converter->convertType(op->getResult(0).getType())
+                          .cast<RankedTensorType>();
+
+    llvm::SmallVector<Value> dynSizes;
+    for (auto [index, dim] : llvm::enumerate(resultType.getShape())) {
+      if (ShapedType::isDynamic(dim)) {
+        dynSizes.push_back(rewriter.create<tensor::DimOp>(loc, operand, index));
+      }
+    }
+
+    llvm::SmallVector<utils::IteratorType> iterators(
+        resultType.getRank(), utils::IteratorType::parallel);
+    llvm::SmallVector<AffineMap> maps(
+        4, {rewriter.getMultiDimIdentityMap(resultType.getRank())});
+    auto broadcastMap = AffineMap::get(
+        resultType.getRank(), /*symbolCount=*/0,
+        {rewriter.getAffineDimExpr(axisAttr.getInt())}, rewriter.getContext());
+    maps[1] = broadcastMap;
+    maps[2] = broadcastMap;
+
+    auto empty =
+        rewriter.create<tensor::EmptyOp>(op.getLoc(), resultType, dynSizes);
+    auto linalgOp = rewriter.create<linalg::GenericOp>(
+        loc, resultType, ValueRange{operand, scale, zeropoint},
+        ValueRange{empty}, maps, iterators,
+        [&](OpBuilder &b, Location loc, ValueRange args) {
+          Value operand = args[0];
+          Value scale = args[1];
+          Value zeropoint = args[2];
+          if (operandDTy.isUnsignedInteger(8)) {
+            operand = b.create<arith::ExtUIOp>(loc, b.getI32Type(), operand);
+          } else if (operandDTy.isSignedInteger(8)) {
+            operand = b.create<arith::ExtSIOp>(loc, b.getI32Type(), operand);
+          }
+
+          if (zeropointDTy.isUnsignedInteger(8)) {
+            zeropoint =
+                b.create<arith::ExtUIOp>(loc, b.getI32Type(), zeropoint);
+          } else if (zeropointDTy.isSignedInteger(8)) {
+            zeropoint =
+                b.create<arith::ExtSIOp>(loc, b.getI32Type(), zeropoint);
+          }
+
+          Value sub = rewriter.create<arith::SubIOp>(loc, operand, zeropoint);
+          Value fp =
+              rewriter.create<arith::SIToFPOp>(loc, args[3].getType(), sub);
+          Value mul = rewriter.create<arith::MulFOp>(loc, fp, scale);
+          b.create<linalg::YieldOp>(loc, mul);
+        });
+    rewriter.replaceOp(op, linalgOp.getResults());
+    return success();
+  }
+};
+} // namespace
+
+namespace {
+
+template <typename OpTy>
+class ConvertCastEquivalentOp : public OpConversionPattern<OpTy> {
+  using OpConversionPattern<OpTy>::OpConversionPattern;
+  using OpAdaptor = typename OpTy::Adaptor;
+
+  LogicalResult
+  matchAndRewrite(OpTy op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto converter = this->getTypeConverter();
+    RankedTensorType resultType = cast<RankedTensorType>(
+        converter->convertType(op->getResult(0).getType()));
     rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType,
                                                 adaptor.getSelf());
     return success();

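The `linalg.generic` built above reads the integer tensor and writes the output through identity indexing maps, while `scale` and `zero_point` (`maps[1]` and `maps[2]`) use a rank-1 broadcast map that indexes only the quantization axis, so each output element sees its own channel's parameters. A loop-nest sketch of what the generic computes for the rank-2, `axis == 1` case (hypothetical helper, for illustration only):

def dequantize_per_channel_2d(operand, scale, zeropoint):
    # Rank-2, axis == 1 view of the linalg.generic above: the (i, j) iteration
    # space is fully parallel; scale/zeropoint are indexed by j (the axis) only.
    rows, cols = len(operand), len(operand[0])
    out = [[0.0] * cols for _ in range(rows)]
    for i in range(rows):
        for j in range(cols):
            out[i][j] = float(operand[i][j] - zeropoint[j]) * scale[j]
    return out
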
@@ -2283,6 +2376,11 @@ void mlir::torch::torch_to_linalg::populateUncategorizedPatternsAndLegality(
   target.addIllegalOp<TensorStaticInfoCastOp>();
   patterns.add<ConvertAtenIntReprOp>(typeConverter, context);
   target.addIllegalOp<AtenIntReprOp>();
-  patterns.add<ConvertMakePerTensorQuantizedTensorOp>(typeConverter, context);
+  patterns.add<ConvertCastEquivalentOp<Aten_MakePerChannelQuantizedTensorOp>>(
+      typeConverter, context);
+  target.addIllegalOp<Aten_MakePerChannelQuantizedTensorOp>();
+  patterns.add<ConvertCastEquivalentOp<Aten_MakePerTensorQuantizedTensorOp>>(
+      typeConverter, context);
   target.addIllegalOp<Aten_MakePerTensorQuantizedTensorOp>();
+  patterns.add<ConvertDequantizePerChannel>(typeConverter, context);
 }

@@ -6549,6 +6549,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
 "    return %0 : !torch.list<int>\n"
 "  }\n"
+"  func.func @\"__torch_mlir_shape_fn.aten.quantize_per_channel\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>, %arg3: !torch.int, %arg4: !torch.int) -> !torch.list<int> {\n"
+"    %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
+"    return %0 : !torch.list<int>\n"
+"  }\n"
 "  func.func @\"__torch_mlir_shape_fn.aten.quantize_per_tensor\"(%arg0: !torch.list<int>, %arg1: !torch.float, %arg2: !torch.int, %arg3: !torch.int) -> !torch.list<int> {\n"
 "    %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
 "    return %0 : !torch.list<int>\n"

@@ -6565,6 +6569,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
 "    return %0 : !torch.list<int>\n"
 "  }\n"
+"  func.func @\"__torch_mlir_shape_fn.aten._make_per_channel_quantized_tensor\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>, %arg3: !torch.int) -> !torch.list<int> {\n"
+"    %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
+"    return %0 : !torch.list<int>\n"
+"  }\n"
 "  func.func @\"__torch_mlir_shape_fn.aten._make_per_tensor_quantized_tensor\"(%arg0: !torch.list<int>, %arg1: !torch.float, %arg2: !torch.int) -> !torch.list<int> {\n"
 "    %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
 "    return %0 : !torch.list<int>\n"

@@ -12632,6 +12640,9 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
 "    return %0#1 : !torch.int\n"
 "  }\n"
+"  func.func @\"__torch_mlir_dtype_fn.aten.quantize_per_channel\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.tuple<int, int>, %arg3: !torch.int, %arg4: !torch.int) -> !torch.int {\n"
+"    return %arg4 : !torch.int\n"
+"  }\n"
 "  func.func @\"__torch_mlir_dtype_fn.aten.quantize_per_tensor\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.float, %arg2: !torch.int, %arg3: !torch.int) -> !torch.int {\n"
 "    return %arg3 : !torch.int\n"
 "  }\n"

@@ -12664,6 +12675,27 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    }\n"
 "    return %2 : !torch.int\n"
 "  }\n"
+"  func.func @\"__torch_mlir_dtype_fn.aten._make_per_channel_quantized_tensor\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.tuple<int, int>, %arg3: !torch.int) -> !torch.int {\n"
+"    %int14 = torch.constant.int 14\n"
+"    %int12 = torch.constant.int 12\n"
+"    %int1 = torch.constant.int 1\n"
+"    %int13 = torch.constant.int 13\n"
+"    %int0 = torch.constant.int 0\n"
+"    %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
+"    %1 = torch.aten.eq.int %0#1, %int0 : !torch.int, !torch.int -> !torch.bool\n"
+"    %2 = torch.prim.If %1 -> (!torch.int) {\n"
+"      torch.prim.If.yield %int13 : !torch.int\n"
+"    } else {\n"
+"      %3 = torch.aten.eq.int %0#1, %int1 : !torch.int, !torch.int -> !torch.bool\n"
+"      %4 = torch.prim.If %3 -> (!torch.int) {\n"
+"        torch.prim.If.yield %int12 : !torch.int\n"
+"      } else {\n"
+"        torch.prim.If.yield %int14 : !torch.int\n"
+"      }\n"
+"      torch.prim.If.yield %4 : !torch.int\n"
+"    }\n"
+"    return %2 : !torch.int\n"
+"  }\n"
 "  func.func @\"__torch_mlir_dtype_fn.aten._make_per_tensor_quantized_tensor\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.float, %arg2: !torch.int) -> !torch.int {\n"
 "    %int14 = torch.constant.int 14\n"
 "    %int12 = torch.constant.int 12\n"

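The bare integer constants in the dtype function above are torch `ScalarType` codes; decoding them (my reading of the enum, worth double-checking): `0` = `uint8`, `1` = `int8`, `12` = `qint8`, `13` = `quint8`, `14` = `qint32`. In Python terms the `torch.prim.If` chain is roughly:

def make_per_channel_quantized_tensor_dtype(self_dtype: int) -> int:
    # Same branch structure as the torch.prim.If chain above, with the
    # ScalarType codes spelled out (0=uint8, 1=int8, 12=qint8, 13=quint8, 14=qint32).
    if self_dtype == 0:   # uint8 -> quint8
        return 13
    if self_dtype == 1:   # int8 -> qint8
        return 12
    return 14             # anything else -> qint32
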
@@ -39,6 +39,20 @@ std::vector<torch::lazy::Shape> compute_shape_div(const at::Tensor& self,
   return {Shape(self.scalar_type(), self.sizes().vec())};
 }
 
+std::vector<torch::lazy::Shape>
+compute_shape__make_per_channel_quantized_tensor(const at::Tensor &self,
+                                                 const at::Tensor &scale,
+                                                 const at::Tensor &zero_point,
+                                                 int64_t axis) {
+  if (self.scalar_type() == at::kChar)
+    return {Shape(at::kQInt8, self.sizes().vec())};
+  if (self.scalar_type() == at::kByte)
+    return {Shape(at::kQUInt8, self.sizes().vec())};
+  if (self.scalar_type() == at::kInt)
+    return {Shape(at::kQInt32, self.sizes().vec())};
+  assert(false);
+}
+
 std::vector<torch::lazy::Shape> compute_shape__make_per_tensor_quantized_tensor(
     const at::Tensor &self, double scale, int64_t zero_point) {
   if (self.scalar_type() == at::kChar)

@@ -75,6 +89,12 @@ std::vector<torch::lazy::Shape> compute_shape_isinf(const at::Tensor& self) {
   return {Shape(at::kBool, self.sizes().vec())};
 }
 
+std::vector<torch::lazy::Shape> compute_shape_quantize_per_channel(
+    const at::Tensor &self, const at::Tensor &scales,
+    const at::Tensor &zero_points, int64_t axis, at::ScalarType dtype) {
+  return {Shape(dtype, self.sizes().vec())};
+}
+
 std::vector<torch::lazy::Shape> compute_shape_max_pool3d_with_indices(
     const at::Tensor& self, at::IntArrayRef kernel_size,
     at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation,

@@ -313,6 +313,7 @@ TORCHDYNAMO_XFAIL_SET = {
     "GroupNormNoWeightAndBiasModule_basic",
 
     # Dynamo does not support tracing quantized tensors
+    "ElementwiseDequantizePerChannelModule_basic",
     "ElementwiseDequantizePerTensorModule_basic",
     "ElementwiseQuantizePerTensorModule_basic",
     "AtenMmQuint8_basic",

@@ -251,6 +251,9 @@ def aten〇clamp_max〡shape(self: List[int], max: float) -> List[int]:
 def aten〇rsub〇Scalar〡shape(self: List[int], other: float, alpha: float = 1) -> List[int]:
     return upstream_shape_functions.unary(self)
 
+def aten〇quantize_per_channel〡shape(self: List[int], scales: List[int], zero_points: List[int], axis: int, dtype: int) -> List[int]:
+    return upstream_shape_functions.unary(self)
+
 def aten〇quantize_per_tensor〡shape(self: List[int], scale: float, zero_point: int, dtype: int) -> List[int]:
     return upstream_shape_functions.unary(self)
 

@@ -263,6 +266,9 @@ def aten〇dequantize〇tensor〡shape(qtensor: List[int]) -> List[int]:
 def aten〇int_repr〡shape(self: List[int]) -> List[int]:
     return upstream_shape_functions.unary(self)
 
+def aten〇_make_per_channel_quantized_tensor〡shape(self: List[int], scale: List[int], zero_point: List[int], axis: int) -> List[int]:
+    return upstream_shape_functions.unary(self)
+
 def aten〇_make_per_tensor_quantized_tensor〡shape(self: List[int], scale: float, zero_point: int) -> List[int]:
     return upstream_shape_functions.unary(self)
 

@@ -4280,6 +4286,9 @@ def prims〇collapse〡dtype(a_rank_dtype: Tuple[int, int], start: int, end: int
     return a_dtype
 
 
+def aten〇quantize_per_channel〡dtype(self_rank_dtype: Tuple[int, int], scales_rank_dtype: Tuple[int, int], zero_points_rank_dtype: Tuple[int, int], axis: int, dtype: int) -> int:
+    return dtype
+
 def aten〇quantize_per_tensor〡dtype(self_rank_dtype: Tuple[int, int], scale: float, zero_point: int, dtype: int) -> int:
     return dtype
 

@@ -4297,6 +4306,14 @@ def aten〇int_repr〡dtype(self_rank_dtype: Tuple[int, int]) -> int:
         return torch.int8
     return torch.int32
 
+def aten〇_make_per_channel_quantized_tensor〡dtype(self_rank_dtype: Tuple[int, int], scale_rank_dtype: Tuple[int, int], zero_point_rank_dtype: Tuple[int, int], axis: int) -> int:
+    self_rank, self_dtype = self_rank_dtype
+    if (self_dtype == torch.uint8):
+        return torch.quint8
+    if (self_dtype == torch.int8):
+        return torch.qint8
+    return torch.qint32
+
 def aten〇_make_per_tensor_quantized_tensor〡dtype(self_rank_dtype: Tuple[int, int], scale: float, zero_point: int) -> int:
     self_rank, self_dtype = self_rank_dtype
     if (self_dtype == torch.uint8):

@@ -820,10 +820,12 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
     emit("aten::leaky_relu_backward : (Tensor, Tensor, Scalar, bool) -> (Tensor)")
 
     # quantized ops
+    emit("aten::quantize_per_channel : (Tensor, Tensor, Tensor, int, int) -> (Tensor)")
     emit("aten::quantize_per_tensor : (Tensor, float, int, int) -> (Tensor)")
     emit("aten::dequantize.self : (Tensor) -> (Tensor)")
     emit("aten::dequantize.tensor : (Tensor) -> (Tensor)")
     emit("aten::int_repr : (Tensor) -> (Tensor)")
+    emit("aten::_make_per_channel_quantized_tensor : (Tensor, Tensor, Tensor, int) -> (Tensor)")
     emit("aten::_make_per_tensor_quantized_tensor : (Tensor, float, int) -> (Tensor)")
 
     # ==========================================================================

@@ -4328,6 +4328,33 @@ def ElementwiseDequantizePerTensorModule_basic(module, tu: TestUtils):
 
 # ==============================================================================
 
+class ElementwiseDequantizePerChannelModule(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([3, 4], torch.int8, True),
+        ([4], torch.int8, True),
+        ([4], torch.float, True),
+    ])
+    def forward(self, x, zeropoint, scale):
+        qx = torch._make_per_channel_quantized_tensor(x, scale, zeropoint, axis=1)
+        qx = torch.dequantize(qx)
+        return qx
+
+@register_test_case(module_factory=lambda: ElementwiseDequantizePerChannelModule())
+def ElementwiseDequantizePerChannelModule_basic(module, tu: TestUtils):
+    module.forward(
+        tu.randint(3, 4, low=-128, high=127).to(torch.int8),
+        tu.randint(4, low=-128, high=127).to(torch.int8),
+        tu.rand(4)
+    )
+
+# ==============================================================================
+
 class GluStaticModule(torch.nn.Module):
     def __init__(self):
         super().__init__()