[ONNX] Preliminary Work Towards Supporting QuantizedMLP_basic onnx e2e test (#3089)

See the related issues here:
[SHARK-Turbine#556](https://github.com/nod-ai/SHARK-Turbine/issues/556)

1. Adds uint8 casting to the onnx.Cast op.
2. Fixes an issue with onnx.DequantizeLinear when the scale has shape [1].
3. Adds support for unsigned types in the AtenItemOp folder.
4. Adds a simpler quantized model for easier debugging.
5. Adds a fusion pass that converts [quant -> dequant -> transpose -> mm]
   patterns into [transpose -> quant -> mm] (see the sketch after this list).
6. Moves some xfails that still fail, but for reasons other than onnx.Cast
   lowering failures.
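
The rewrite in item 5 is sound because per-tensor quantization commutes with a
transpose: a single scale and zero point apply to every element, so reordering
the layout op does not change any value. A minimal PyTorch sketch of that
property (illustrative values only, not part of this patch):

```python
import torch

# Per-tensor quantization uses one scale/zero point for the whole tensor, so it
# commutes with a pure layout op such as transpose.
x = torch.randn(4, 3)
scale, zero_point = 0.05, 3

quant_then_transpose = torch.quantize_per_tensor(
    x, scale, zero_point, torch.qint8).dequantize().t()
transpose_then_quant = torch.quantize_per_tensor(
    x.t().contiguous(), scale, zero_point, torch.qint8).dequantize()

assert torch.equal(quant_then_transpose, transpose_then_quant)
```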
zjgarvey 2024-04-01 18:21:05 -05:00 committed by GitHub
parent 3c33dbd987
commit 532d297c46
7 changed files with 128 additions and 17 deletions


@@ -43,6 +43,8 @@ static int64_t onnxDtypeIntToTorchDtypeInt(int64_t dtypeIntOnnx) {
   switch (dtypeIntOnnx) {
   case 1:
     return 6; // float
+  case 2:
+    return 0; // uint8
   case 3:
     return 1; // int8
   case 6:
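
The integer codes on the left of this mapping are ONNX's TensorProto dtype enum
and the ones on the right are PyTorch's ScalarType numbering (Byte/uint8 is 0,
Char/int8 is 1, Float is 6). A quick sanity check of the ONNX side, assuming
the onnx Python package is installed:

```python
# ONNX stores dtypes as small integers in TensorProto; uint8 is code 2, which
# onnxDtypeIntToTorchDtypeInt now maps to 0 (torch's ScalarType::Byte).
from onnx import TensorProto

assert TensorProto.FLOAT == 1
assert TensorProto.UINT8 == 2
assert TensorProto.INT8 == 3
```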
@@ -1425,8 +1427,8 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
        if (!resultType.hasDtype())
          return rewriter.notifyMatchFailure(binder.op,
                                             "requires known result dtype");
-        if (scaleTy.getSizes().size() == 0) {
+        if (scaleTy.getSizes().size() == 0 ||
+            (scaleTy.getSizes().size() == 1 && scaleTy.getSizes()[0] == 1)) {
          Type qTy = operandTy.getDtype();
          if (qTy.isUnsignedInteger(8)) {
@@ -1455,7 +1457,8 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
          return success();
        }
-        return failure();
+        return rewriter.notifyMatchFailure(binder.op,
+                                           "unimplemented: non-scalar scale");
      });
  patterns.onOp("Div", 7,
                [](OpBinder binder, ConversionPatternRewriter &rewriter) {
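
The shape-[1] scale is accepted alongside the rank-0 scale because a one-element
scale broadcasts exactly like a scalar in DequantizeLinear's
y = (x - zero_point) * scale formula, so both can take the per-tensor path. A
small NumPy sketch of that equivalence (illustrative values, not from the patch):

```python
import numpy as np

# DequantizeLinear computes (x - zero_point) * scale elementwise; a scale of
# shape [1] broadcasts the same way a rank-0 scalar does.
x = np.array([[0, 100, 255]], dtype=np.uint8)
zero_point = 128

scalar_scale = np.float32(0.5)                    # rank-0 scale
shape1_scale = np.array([0.5], dtype=np.float32)  # shape-[1] scale

y_scalar = (x.astype(np.int32) - zero_point) * scalar_scale
y_shape1 = (x.astype(np.int32) - zero_point) * shape1_scale

assert np.array_equal(y_scalar, y_shape1)
```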


@@ -97,7 +97,7 @@ public:
    getZeroPoint(op.getSelf(), lhsZeroPoint);
    getZeroPoint(op.getMat2(), rhsZeroPoint);

-    if (static_cast<bool>(lhsZeroPoint) != static_cast<bool>(lhsZeroPoint)) {
+    if (static_cast<bool>(lhsZeroPoint) != static_cast<bool>(rhsZeroPoint)) {
      return rewriter.notifyMatchFailure(
          op, "unsupported: aten.mm with mixed quantization");
    }
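
The original condition compared lhsZeroPoint with itself, so it could never
fire; the intent is to reject the match when exactly one of the two matmul
operands has a zero point. A Python sketch of the predicate, using None to
stand in for a null mlir::Value (illustrative only):

```python
def mixed_quantization(lhs_zero_point, rhs_zero_point):
    # A null Value from getZeroPoint (here: None) means the operand is not quantized.
    return (lhs_zero_point is not None) != (rhs_zero_point is not None)

assert mixed_quantization(3, None)         # only one side quantized -> bail out
assert not mixed_quantization(3, 0)        # both quantized -> proceed
assert not mixed_quantization(None, None)  # neither quantized -> proceed
```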


@@ -3798,7 +3798,9 @@ OpFoldResult AtenItemOp::fold(FoldAdaptor adaptor) {
  if (matchPattern(getOperand(), m_Constant(&attr))) {
    auto splat = attr.getSplatValue<Attribute>();
    if (auto intAttr = dyn_cast<IntegerAttr>(splat)) {
-      return getI64IntegerAttr(getContext(), intAttr.getSInt());
+      return intAttr.getType().isUnsignedInteger()
+                 ? getI64IntegerAttr(getContext(), intAttr.getUInt())
+                 : getI64IntegerAttr(getContext(), intAttr.getSInt());
    }
    if (auto floatAttr = dyn_cast<FloatAttr>(splat)) {
      return getF64FloatAttr(getContext(), floatAttr.getValueAsDouble());
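
Reading an unsigned splat with getSInt() reinterprets large unsigned values as
negative ones, which is why the folder now dispatches on signedness. A NumPy
illustration of the failure mode (not part of the patch):

```python
import numpy as np

# A uint8 splat of 255 must fold to 255; reading the same bits as a signed
# 8-bit integer (what an unconditional getSInt() amounts to) yields -1.
splat = np.array([255], dtype=np.uint8)

assert int(splat[0]) == 255               # unsigned read, i.e. getUInt()
assert int(splat.view(np.int8)[0]) == -1  # signed reinterpretation, i.e. getSInt()
```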


@@ -54,6 +54,68 @@ public:
  }
};

+template <typename SrcOp>
+class QuantizeTransposedOperands : public OpRewritePattern<SrcOp> {
+public:
+  using OpRewritePattern<SrcOp>::OpRewritePattern;
+  LogicalResult matchAndRewrite(SrcOp op,
+                                PatternRewriter &rewriter) const override {
+    llvm::SmallVector<Value> operands(op->getOperands());
+    unsigned numOperands = operands.size();
+    bool dequanted = false;
+    for (unsigned i = 0; i < numOperands; i++) {
+      if (auto trans = operands[i].getDefiningOp<AtenTransposeIntOp>()) {
+        auto transOperands = trans.getOperands();
+        Value dequantOperand;
+        if (auto dequant =
+                transOperands[0].getDefiningOp<AtenDequantizeSelfOp>()) {
+          dequantOperand = dequant.getOperand();
+          if (auto quant =
+                  dequantOperand
+                      .getDefiningOp<Aten_MakePerTensorQuantizedTensorOp>()) {
+            auto quantOperands = quant.getOperands();
+            auto qType = quantOperands[0]
+                             .getType()
+                             .cast<ValueTensorType>()
+                             .getOptionalDtype();
+            auto torchQType =
+                quant.getType().cast<ValueTensorType>().getOptionalDtype();
+            auto transQTy =
+                rewriter.getType<ValueTensorType>(trans.getResult()
+                                                      .getType()
+                                                      .cast<ValueTensorType>()
+                                                      .getOptionalSizes(),
+                                                  qType);
+            auto newQuantTy =
+                rewriter.getType<ValueTensorType>(trans.getResult()
+                                                      .getType()
+                                                      .cast<ValueTensorType>()
+                                                      .getOptionalSizes(),
+                                                  torchQType);
+            Value newTrans = rewriter.create<AtenTransposeIntOp>(
+                op.getLoc(), transQTy, quantOperands[0], transOperands[1],
+                transOperands[2]);
+            Value newQuant =
+                rewriter.create<Aten_MakePerTensorQuantizedTensorOp>(
+                    op.getLoc(), newQuantTy, newTrans, quantOperands[1],
+                    quantOperands[2]);
+            operands[i] = newQuant;
+            dequanted = true;
+          }
+        }
+      }
+    }
+    if (!dequanted) {
+      return rewriter.notifyMatchFailure(
+          op, "no dequantized transpose inputs found.");
+    }
+    rewriter.replaceOpWithNewOp<SrcOp>(op, op.getType(), operands);
+    return success();
+  }
+};
+
template <typename SrcOp> class QuantizeBias : public OpRewritePattern<SrcOp> {
public:
  using OpRewritePattern<SrcOp>::OpRewritePattern;
@@ -217,13 +279,15 @@ public:
  void runOnOperation() override {
    MLIRContext *context = &getContext();
    RewritePatternSet patterns(context);
-    patterns
-        .insert<RemoveUnused<AtenDequantizeSelfOp>,
-                RemoveUnused<AtenDequantizeTensorOp>,
-                RemoveUnused<AtenQuantizePerTensorOp>,
-                QuantizeOperands<AtenConvolutionOp>, QuantizeOperands<AtenMmOp>,
-                QuantizeAccumulator<AtenMmOp>, QuantizeBias<AtenConvolutionOp>>(
-            context);
+    patterns.insert<
+        RemoveUnused<AtenDequantizeSelfOp>,
+        RemoveUnused<AtenDequantizeTensorOp>,
+        RemoveUnused<AtenQuantizePerTensorOp>,
+        RemoveUnused<Aten_MakePerTensorQuantizedTensorOp>,
+        RemoveUnused<AtenTransposeIntOp>, QuantizeOperands<AtenConvolutionOp>,
+        QuantizeOperands<AtenMmOp>, QuantizeTransposedOperands<AtenMmOp>,
+        QuantizeAccumulator<AtenMmOp>, QuantizeBias<AtenConvolutionOp>>(
+        context);

    GreedyRewriteConfig config;
    if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns),


@@ -149,6 +149,7 @@ TORCHDYNAMO_XFAIL_SET = {
    'AtenFloatScalarModule_basic',
    'AtenIntBoolOpModule_basic',
    'QuantizedMLP_basic',
+    'QuantizedSingleLayer_basic',
    'ScalarImplicitFloatModule_basic',
    'ScalarImplicitIntModule_basic',
    # END tests failing due to: torch._dynamo.exc.Unsupported: data dependent operator: aten._local_scalar_dense.default
@@ -1412,6 +1413,7 @@ LTC_XFAIL_SET = {
    "NeFloatIntModule_basic",
    "NeIntModule_basic",
    "QuantizedMLP_basic",
+    "QuantizedSingleLayer_basic",
    "ScalarImplicitFloatModule_basic",
    "ScalarImplicitIntModule_basic",
    "SliceEndSleStartModule_basic",
@@ -1911,11 +1913,6 @@ ONNX_XFAIL_SET = {
    "AdaptiveAvgPool1dGeneralDynamicNoBatches_basic",
    "AvgPool2dDivisorOverrideModule_basic",
-    # Failure - onnx_lowering: onnx.Cast
-    "BucketizeTensorOutInt32RightModule_basic",
-    "ElementwiseToDtypeI64ToUI8Module_basic",
-    "QuantizedMLP_basic",
    # Failure - onnx_lowering: onnx.Clip
    "NormalizeModule_basic",
@@ -2054,12 +2051,20 @@ ONNX_XFAIL_SET = {
    # Failure - incorrect dtype
    "ReduceMaxAlongDimUnsignedInt_basic",
+    "ElementwiseToDtypeI64ToUI8Module_basic",
    # Failure - torch.aten.view lower
    "ViewSizeDimFollowedByExpandedOnesModule_basic",
    "ViewSizeDimLedAndFollowedByExpandedOnesModule_basic",
    "ViewSizeDimLedByExpandedOnesModule_basic",
+    # Failure - torch.aten.mm lower (mixed signedness of qtypes)
+    "QuantizedMLP_basic",
+    "QuantizedSingleLayer_basic",
+    # Failure - torch.aten.squeeze lower
+    "BucketizeTensorOutInt32RightModule_basic", # unsupported by backend contract: tensor with unknown rank
    # Failure - unknown
    "BucketizeTensorFloatModule_basic",
    "BucketizeTensorModule_basic",


@@ -12,6 +12,7 @@ from torch_mlir._version import torch_version_for_comparison, version
COMMON_TORCH_MLIR_LOWERING_XFAILS = {
    "NativeGroupNormBackwardModule_basic",
    "QuantizedMLP_basic",
+    "QuantizedSingleLayer_basic",
    "ReduceMaxAlongDimUnsignedInt_basic",
    "ReduceMinAlongDimUnsignedInt_basic",
    "ElementwiseToDtypeI64ToUI8Module_basic",


@@ -12,6 +12,28 @@ from torch_mlir_e2e_test.annotations import annotate_args, export

# ==============================================================================

+class QuantizedSingleLayer(nn.Module):
+    def __init__(self):
+        super().__init__()
+        torch.random.manual_seed(0)
+        self.layers = nn.Sequential(
+            nn.Linear(16, 8),
+        )
+        self.quantize = torch.quantization.QuantStub()
+        self.dequantize = torch.quantization.DeQuantStub()
+
+    @export
+    @annotate_args([
+        None,
+        ([1, 16], torch.float32, True),
+    ])
+    def forward(self, x):
+        x = self.quantize(x)
+        x = self.layers(x)
+        x = self.dequantize(x)
+        return x
+
+
class QuantizedMLP(nn.Module):
    def __init__(self):
@@ -53,6 +75,20 @@ def get_quantized_mlp():
    torch.quantization.convert(model, inplace=True)
    return model

+
+def get_quantized_single_layer():
+    model = QuantizedSingleLayer()
+    model.eval()
+    model.qconfig = torch.quantization.default_qconfig
+    torch.quantization.prepare(model, inplace=True)
+    torch.manual_seed(0)
+    for _ in range(32):
+        model(get_mlp_input())
+    torch.quantization.convert(model, inplace=True)
+    return model
+
+
+@register_test_case(module_factory=get_quantized_single_layer)
+def QuantizedSingleLayer_basic(module, tu: TestUtils):
+    module.forward(get_mlp_input())
+

@register_test_case(module_factory=get_quantized_mlp)
def QuantizedMLP_basic(module, tu: TestUtils):