[torch] `torch.dequantize` for per channel tensors to` linalg` (#2769)

Support a lowering for dequantization for per channel tensors from
`torch` dialect to a linalg decomposition. Tested via a numerical
`torch` test.
pull/2747/head
Rob Suderman 2024-01-25 16:40:21 -08:00 committed by GitHub
parent 0aed231e21
commit 2ef228328f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 258 additions and 8 deletions

View File

@ -14465,6 +14465,33 @@ def Torch_AtenLeakyReluBackwardOp : Torch_Op<"aten.leaky_relu_backward", [
}];
}
def Torch_AtenQuantizePerChannelOp : Torch_Op<"aten.quantize_per_channel", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::quantize_per_channel : (Tensor, Tensor, Tensor, int, int) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
AnyTorchTensorType:$scales,
AnyTorchTensorType:$zero_points,
Torch_IntType:$axis,
Torch_IntType:$dtype
);
let results = (outs
AnyTorchTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenQuantizePerChannelOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 5, 1);
}
void AtenQuantizePerChannelOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 5, 1);
}
}];
}
def Torch_AtenQuantizePerTensorOp : Torch_Op<"aten.quantize_per_tensor", [
AllowsTypeRefinement,
HasValueSemantics,
@ -14560,6 +14587,32 @@ def Torch_AtenIntReprOp : Torch_Op<"aten.int_repr", [
}];
}
def Torch_Aten_MakePerChannelQuantizedTensorOp : Torch_Op<"aten._make_per_channel_quantized_tensor", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::_make_per_channel_quantized_tensor : (Tensor, Tensor, Tensor, int) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
AnyTorchTensorType:$scale,
AnyTorchTensorType:$zero_point,
Torch_IntType:$axis
);
let results = (outs
AnyTorchTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult Aten_MakePerChannelQuantizedTensorOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 4, 1);
}
void Aten_MakePerChannelQuantizedTensorOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 4, 1);
}
}];
}
def Torch_Aten_MakePerTensorQuantizedTensorOp : Torch_Op<"aten._make_per_tensor_quantized_tensor", [
AllowsTypeRefinement,
HasValueSemantics,

View File

@ -1344,7 +1344,7 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
auto makeQTensor =
qtensor.getDefiningOp<Aten_MakePerTensorQuantizedTensorOp>();
if (!makeQTensor) {
op->emitError(
op->emitWarning(
"unimplemented: dequantizing tensor of unknown scale / zero-point");
return nullptr;
}
@ -2221,16 +2221,109 @@ public:
} // namespace
namespace {
class ConvertMakePerTensorQuantizedTensorOp
: public OpConversionPattern<Aten_MakePerTensorQuantizedTensorOp> {
class ConvertDequantizePerChannel
: public OpConversionPattern<AtenDequantizeSelfOp> {
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(Aten_MakePerTensorQuantizedTensorOp op, OpAdaptor adaptor,
matchAndRewrite(AtenDequantizeSelfOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
RankedTensorType resultType = getTypeConverter()
->convertType(op->getResult(0).getType())
.cast<RankedTensorType>();
auto loc = op.getLoc();
auto qoperand = op.getOperand();
auto make = qoperand.getDefiningOp<Aten_MakePerChannelQuantizedTensorOp>();
if (!make) {
llvm::errs() << "Did not find make per channel\n";
return rewriter.notifyMatchFailure(op, "did not find per channel qint");
}
auto converter = getTypeConverter();
auto operand = make.getOperand(0);
auto scale = make.getScale();
auto zeropoint = make.getZeroPoint();
auto axis = make.getAxis();
IntegerAttr axisAttr;
if (!matchPattern(axis, m_Constant(&axisAttr))) {
return failure();
}
auto operandDTy = operand.getType().cast<ValueTensorType>().getDtype();
auto zeropointDTy = zeropoint.getType().cast<ValueTensorType>().getDtype();
operand = converter->materializeTargetConversion(
rewriter, loc, converter->convertType(operand.getType()), operand);
scale = converter->materializeTargetConversion(
rewriter, loc, converter->convertType(scale.getType()), scale);
zeropoint = converter->materializeTargetConversion(
rewriter, loc, converter->convertType(zeropoint.getType()), zeropoint);
auto resultType = converter->convertType(op->getResult(0).getType())
.cast<RankedTensorType>();
llvm::SmallVector<Value> dynSizes;
for (auto [index, dim] : llvm::enumerate(resultType.getShape())) {
if (ShapedType::isDynamic(dim)) {
dynSizes.push_back(rewriter.create<tensor::DimOp>(loc, operand, index));
}
}
llvm::SmallVector<utils::IteratorType> iterators(
resultType.getRank(), utils::IteratorType::parallel);
llvm::SmallVector<AffineMap> maps(
4, {rewriter.getMultiDimIdentityMap(resultType.getRank())});
auto broadcastMap = AffineMap::get(
resultType.getRank(), /*symbolCount=*/0,
{rewriter.getAffineDimExpr(axisAttr.getInt())}, rewriter.getContext());
maps[1] = broadcastMap;
maps[2] = broadcastMap;
auto empty =
rewriter.create<tensor::EmptyOp>(op.getLoc(), resultType, dynSizes);
auto linalgOp = rewriter.create<linalg::GenericOp>(
loc, resultType, ValueRange{operand, scale, zeropoint},
ValueRange{empty}, maps, iterators,
[&](OpBuilder &b, Location loc, ValueRange args) {
Value operand = args[0];
Value scale = args[1];
Value zeropoint = args[2];
if (operandDTy.isUnsignedInteger(8)) {
operand = b.create<arith::ExtUIOp>(loc, b.getI32Type(), operand);
} else if (operandDTy.isSignedInteger(8)) {
operand = b.create<arith::ExtSIOp>(loc, b.getI32Type(), operand);
}
if (zeropointDTy.isUnsignedInteger(8)) {
zeropoint =
b.create<arith::ExtUIOp>(loc, b.getI32Type(), zeropoint);
} else if (zeropointDTy.isSignedInteger(8)) {
zeropoint =
b.create<arith::ExtSIOp>(loc, b.getI32Type(), zeropoint);
}
Value sub = rewriter.create<arith::SubIOp>(loc, operand, zeropoint);
Value fp =
rewriter.create<arith::SIToFPOp>(loc, args[3].getType(), sub);
Value mul = rewriter.create<arith::MulFOp>(loc, fp, scale);
b.create<linalg::YieldOp>(loc, mul);
});
rewriter.replaceOp(op, linalgOp.getResults());
return success();
}
};
} // namespace
namespace {
template <typename OpTy>
class ConvertCastEquivalentOp : public OpConversionPattern<OpTy> {
using OpConversionPattern<OpTy>::OpConversionPattern;
using OpAdaptor = typename OpTy::Adaptor;
LogicalResult
matchAndRewrite(OpTy op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto converter = this->getTypeConverter();
RankedTensorType resultType = cast<RankedTensorType>(
converter->convertType(op->getResult(0).getType()));
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType,
adaptor.getSelf());
return success();
@ -2283,6 +2376,11 @@ void mlir::torch::torch_to_linalg::populateUncategorizedPatternsAndLegality(
target.addIllegalOp<TensorStaticInfoCastOp>();
patterns.add<ConvertAtenIntReprOp>(typeConverter, context);
target.addIllegalOp<AtenIntReprOp>();
patterns.add<ConvertMakePerTensorQuantizedTensorOp>(typeConverter, context);
patterns.add<ConvertCastEquivalentOp<Aten_MakePerChannelQuantizedTensorOp>>(
typeConverter, context);
target.addIllegalOp<Aten_MakePerChannelQuantizedTensorOp>();
patterns.add<ConvertCastEquivalentOp<Aten_MakePerTensorQuantizedTensorOp>>(
typeConverter, context);
target.addIllegalOp<Aten_MakePerTensorQuantizedTensorOp>();
patterns.add<ConvertDequantizePerChannel>(typeConverter, context);
}

View File

@ -6549,6 +6549,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.quantize_per_channel\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>, %arg3: !torch.int, %arg4: !torch.int) -> !torch.list<int> {\n"
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.quantize_per_tensor\"(%arg0: !torch.list<int>, %arg1: !torch.float, %arg2: !torch.int, %arg3: !torch.int) -> !torch.list<int> {\n"
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
@ -6565,6 +6569,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten._make_per_channel_quantized_tensor\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>, %arg3: !torch.int) -> !torch.list<int> {\n"
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten._make_per_tensor_quantized_tensor\"(%arg0: !torch.list<int>, %arg1: !torch.float, %arg2: !torch.int) -> !torch.list<int> {\n"
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
@ -12632,6 +12640,9 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" return %0#1 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.quantize_per_channel\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.tuple<int, int>, %arg3: !torch.int, %arg4: !torch.int) -> !torch.int {\n"
" return %arg4 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.quantize_per_tensor\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.float, %arg2: !torch.int, %arg3: !torch.int) -> !torch.int {\n"
" return %arg3 : !torch.int\n"
" }\n"
@ -12664,6 +12675,27 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" }\n"
" return %2 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten._make_per_channel_quantized_tensor\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.tuple<int, int>, %arg3: !torch.int) -> !torch.int {\n"
" %int14 = torch.constant.int 14\n"
" %int12 = torch.constant.int 12\n"
" %int1 = torch.constant.int 1\n"
" %int13 = torch.constant.int 13\n"
" %int0 = torch.constant.int 0\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" %1 = torch.aten.eq.int %0#1, %int0 : !torch.int, !torch.int -> !torch.bool\n"
" %2 = torch.prim.If %1 -> (!torch.int) {\n"
" torch.prim.If.yield %int13 : !torch.int\n"
" } else {\n"
" %3 = torch.aten.eq.int %0#1, %int1 : !torch.int, !torch.int -> !torch.bool\n"
" %4 = torch.prim.If %3 -> (!torch.int) {\n"
" torch.prim.If.yield %int12 : !torch.int\n"
" } else {\n"
" torch.prim.If.yield %int14 : !torch.int\n"
" }\n"
" torch.prim.If.yield %4 : !torch.int\n"
" }\n"
" return %2 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten._make_per_tensor_quantized_tensor\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.float, %arg2: !torch.int) -> !torch.int {\n"
" %int14 = torch.constant.int 14\n"
" %int12 = torch.constant.int 12\n"

View File

@ -39,6 +39,20 @@ std::vector<torch::lazy::Shape> compute_shape_div(const at::Tensor& self,
return {Shape(self.scalar_type(), self.sizes().vec())};
}
std::vector<torch::lazy::Shape>
compute_shape__make_per_channel_quantized_tensor(const at::Tensor &self,
const at::Tensor &scale,
const at::Tensor &zero_point,
int64_t axis) {
if (self.scalar_type() == at::kChar)
return {Shape(at::kQInt8, self.sizes().vec())};
if (self.scalar_type() == at::kByte)
return {Shape(at::kQUInt8, self.sizes().vec())};
if (self.scalar_type() == at::kInt)
return {Shape(at::kQInt32, self.sizes().vec())};
assert(false);
}
std::vector<torch::lazy::Shape> compute_shape__make_per_tensor_quantized_tensor(
const at::Tensor &self, double scale, int64_t zero_point) {
if (self.scalar_type() == at::kChar)
@ -75,6 +89,12 @@ std::vector<torch::lazy::Shape> compute_shape_isinf(const at::Tensor& self) {
return {Shape(at::kBool, self.sizes().vec())};
}
std::vector<torch::lazy::Shape> compute_shape_quantize_per_channel(
const at::Tensor &self, const at::Tensor &scales,
const at::Tensor &zero_points, int64_t axis, at::ScalarType dtype) {
return {Shape(dtype, self.sizes().vec())};
}
std::vector<torch::lazy::Shape> compute_shape_max_pool3d_with_indices(
const at::Tensor& self, at::IntArrayRef kernel_size,
at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation,

View File

@ -313,6 +313,7 @@ TORCHDYNAMO_XFAIL_SET = {
"GroupNormNoWeightAndBiasModule_basic",
# Dynamo does not support tracing quantized tensors
"ElementwiseDequantizePerChannelModule_basic",
"ElementwiseDequantizePerTensorModule_basic",
"ElementwiseQuantizePerTensorModule_basic",
"AtenMmQuint8_basic",

View File

@ -251,6 +251,9 @@ def atenclamp_max〡shape(self: List[int], max: float) -> List[int]:
def atenrsubScalar〡shape(self: List[int], other: float, alpha: float = 1) -> List[int]:
return upstream_shape_functions.unary(self)
def atenquantize_per_channel〡shape(self: List[int], scales: List[int], zero_points: List[int], axis: int, dtype: int) -> List[int]:
return upstream_shape_functions.unary(self)
def atenquantize_per_tensor〡shape(self: List[int], scale: float, zero_point: int, dtype: int) -> List[int]:
return upstream_shape_functions.unary(self)
@ -263,6 +266,9 @@ def atendequantizetensor〡shape(qtensor: List[int]) -> List[int]:
def atenint_repr〡shape(self: List[int]) -> List[int]:
return upstream_shape_functions.unary(self)
def aten_make_per_channel_quantized_tensor〡shape(self: List[int], scale: List[int], zero_point: List[int], axis: int) -> List[int]:
return upstream_shape_functions.unary(self)
def aten_make_per_tensor_quantized_tensor〡shape(self: List[int], scale: float, zero_point: int) -> List[int]:
return upstream_shape_functions.unary(self)
@ -4280,6 +4286,9 @@ def primscollapse〡dtype(a_rank_dtype: Tuple[int, int], start: int, end: int
return a_dtype
def atenquantize_per_channel〡dtype(self_rank_dtype: Tuple[int, int], scales_rank_dtype: Tuple[int, int], zero_points_rank_dtype: Tuple[int, int], axis: int, dtype: int) -> int:
return dtype
def atenquantize_per_tensor〡dtype(self_rank_dtype: Tuple[int, int], scale: float, zero_point: int, dtype: int) -> int:
return dtype
@ -4297,6 +4306,14 @@ def atenint_repr〡dtype(self_rank_dtype: Tuple[int, int]) -> int:
return torch.int8
return torch.int32
def aten_make_per_channel_quantized_tensor〡dtype(self_rank_dtype: Tuple[int, int], scale_rank_dtype: Tuple[int, int], zero_point_rank_dtype: Tuple[int, int], axis: int) -> int:
self_rank, self_dtype = self_rank_dtype
if (self_dtype == torch.uint8):
return torch.quint8
if (self_dtype == torch.int8):
return torch.qint8
return torch.qint32
def aten_make_per_tensor_quantized_tensor〡dtype(self_rank_dtype: Tuple[int, int], scale: float, zero_point: int) -> int:
self_rank, self_dtype = self_rank_dtype
if (self_dtype == torch.uint8):

View File

@ -820,10 +820,12 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
emit("aten::leaky_relu_backward : (Tensor, Tensor, Scalar, bool) -> (Tensor)")
# quantized ops
emit("aten::quantize_per_channel : (Tensor, Tensor, Tensor, int, int) -> (Tensor)")
emit("aten::quantize_per_tensor : (Tensor, float, int, int) -> (Tensor)")
emit("aten::dequantize.self : (Tensor) -> (Tensor)")
emit("aten::dequantize.tensor : (Tensor) -> (Tensor)")
emit("aten::int_repr : (Tensor) -> (Tensor)")
emit("aten::_make_per_channel_quantized_tensor : (Tensor, Tensor, Tensor, int) -> (Tensor)")
emit("aten::_make_per_tensor_quantized_tensor : (Tensor, float, int) -> (Tensor)")
# ==========================================================================

View File

@ -4328,6 +4328,33 @@ def ElementwiseDequantizePerTensorModule_basic(module, tu: TestUtils):
# ==============================================================================
class ElementwiseDequantizePerChannelModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([3, 4], torch.int8, True),
([4], torch.int8, True),
([4], torch.float, True),
])
def forward(self, x, zeropoint, scale):
qx = torch._make_per_channel_quantized_tensor(x, scale, zeropoint, axis=1)
qx = torch.dequantize(qx)
return qx
@register_test_case(module_factory=lambda: ElementwiseDequantizePerChannelModule())
def ElementwiseDequantizePerChannelModule_basic(module, tu: TestUtils):
module.forward(
tu.randint(3, 4, low=-128, high=127).to(torch.int8),
tu.randint(4, low=-128, high=127).to(torch.int8),
tu.rand(4)
)
# ==============================================================================
class GluStaticModule(torch.nn.Module):
def __init__(self):
super().__init__()