From 532d297c46b13066da851af6eca53b91ff7a13c0 Mon Sep 17 00:00:00 2001
From: zjgarvey <47986913+zjgarvey@users.noreply.github.com>
Date: Mon, 1 Apr 2024 18:21:05 -0500
Subject: [PATCH] [ONNX] Preliminary Work Towards Supporting
 QuantizedMLP_basic onnx e2e test (#3089)

See the related issues here:
[SHARK-Turbine#556](https://github.com/nod-ai/SHARK-Turbine/issues/556)

1. Adds uint8 casting to the onnx.Cast op
2. Fixes an issue with onnx.DequantizeLinear when the scale comes with
   shape [1]
3. Adds support for unsigned types in an AtenItemOp folder
4. Adds a simpler quantized model for easier debugging
5. Adds a fusion pass to convert [quant -> dequant -> transpose -> mm]
   patterns to [transpose -> quant -> mm] (a PyTorch sketch of why this
   is sound follows below)
6. Moves some xfails that still fail, but for reasons other than
   onnx.Cast lowering
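A minimal PyTorch sketch of why the rewrite in (5) is sound (illustrative
only, not part of the diff; the scale and zero point are arbitrary):
per-tensor dequantization is elementwise, so it commutes with transpose,
and the transpose can therefore be hoisted onto the quantized tensor.

```python
import torch

# Arbitrary per-tensor quantization parameters, for illustration only.
w = torch.randn(8, 16)
qw = torch.quantize_per_tensor(w, scale=0.1, zero_point=0, dtype=torch.qint8)

# dequant -> transpose yields exactly the same tensor as transpose -> dequant.
assert torch.equal(qw.dequantize().t(), qw.t().dequantize())

# So a [quant -> dequant -> transpose -> mm] chain can be rewritten so the mm
# consumes a transposed *quantized* operand: [transpose -> quant -> mm].
x = torch.randn(1, 16)
y = torch.mm(x, qw.t().dequantize())
```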
---
 .../TorchOnnxToTorch/DefaultDomainAtoF.cpp    |  9 ++-
 lib/Conversion/TorchToLinalg/Linear.cpp       |  2 +-
 lib/Dialect/Torch/IR/TorchOps.cpp             |  4 +-
 .../Torch/Transforms/FuseQuantizedOps.cpp     | 78 +++++++++++++++++--
 projects/pt1/e2e_testing/xfail_sets.py        | 15 ++--
 .../test_suite/__init__.py                    |  1 +
 .../test_suite/quantized_models.py            | 36 +++++++++
 7 files changed, 128 insertions(+), 17 deletions(-)

diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp
index 6197d04f9..5d4e693d0 100644
--- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp
+++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp
@@ -43,6 +43,8 @@ static int64_t onnxDtypeIntToTorchDtypeInt(int64_t dtypeIntOnnx) {
   switch (dtypeIntOnnx) {
   case 1:
     return 6; // float
+  case 2:
+    return 0; // uint8
   case 3:
     return 1; // int8
   case 6:
@@ -1425,8 +1427,8 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
         if (!resultType.hasDtype())
           return rewriter.notifyMatchFailure(binder.op,
                                              "requires known result dtype");
-
-        if (scaleTy.getSizes().size() == 0) {
+        if (scaleTy.getSizes().size() == 0 ||
+            (scaleTy.getSizes().size() == 1 && scaleTy.getSizes()[0] == 1)) {
           Type qTy = operandTy.getDtype();
 
           if (qTy.isUnsignedInteger(8)) {
@@ -1455,7 +1457,8 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
           return success();
         }
 
-        return failure();
+        return rewriter.notifyMatchFailure(binder.op,
+                                           "unimplemented: non-scalar scale");
       });
   patterns.onOp("Div", 7,
                 [](OpBinder binder, ConversionPatternRewriter &rewriter) {
diff --git a/lib/Conversion/TorchToLinalg/Linear.cpp b/lib/Conversion/TorchToLinalg/Linear.cpp
index 81b3e5d67..9f9e8c2fb 100644
--- a/lib/Conversion/TorchToLinalg/Linear.cpp
+++ b/lib/Conversion/TorchToLinalg/Linear.cpp
@@ -97,7 +97,7 @@ public:
     getZeroPoint(op.getSelf(), lhsZeroPoint);
     getZeroPoint(op.getMat2(), rhsZeroPoint);
 
-    if (static_cast<bool>(lhsZeroPoint) != static_cast<bool>(lhsZeroPoint)) {
+    if (static_cast<bool>(lhsZeroPoint) != static_cast<bool>(rhsZeroPoint)) {
       return rewriter.notifyMatchFailure(
           op, "unsupported: aten.mm with mixed quantization");
     }
diff --git a/lib/Dialect/Torch/IR/TorchOps.cpp b/lib/Dialect/Torch/IR/TorchOps.cpp
index bfb745f5c..ae457a7fe 100644
--- a/lib/Dialect/Torch/IR/TorchOps.cpp
+++ b/lib/Dialect/Torch/IR/TorchOps.cpp
@@ -3798,7 +3798,9 @@ OpFoldResult AtenItemOp::fold(FoldAdaptor adaptor) {
   if (matchPattern(getOperand(), m_Constant(&attr))) {
     auto splat = attr.getSplatValue<Attribute>();
     if (auto intAttr = dyn_cast<IntegerAttr>(splat)) {
-      return getI64IntegerAttr(getContext(), intAttr.getSInt());
+      return intAttr.getType().isUnsignedInteger()
+                 ? getI64IntegerAttr(getContext(), intAttr.getUInt())
+                 : getI64IntegerAttr(getContext(), intAttr.getSInt());
     }
     if (auto floatAttr = dyn_cast<FloatAttr>(splat)) {
       return getF64FloatAttr(getContext(), floatAttr.getValueAsDouble());
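Note on the AtenItemOp folder change above: the new getUInt() path mirrors
eager PyTorch, where item() on a uint8 tensor preserves the unsigned value;
folding the splat 0xFF through a signed read would produce -1 instead of
255. A minimal check (illustrative only, not part of the diff):

```python
import torch

# aten.item on a uint8 splat must yield 255, not the sign-wrapped -1.
t = torch.full((1,), 255, dtype=torch.uint8)
assert t.item() == 255
```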
"SliceEndSleStartModule_basic", @@ -1911,11 +1913,6 @@ ONNX_XFAIL_SET = { "AdaptiveAvgPool1dGeneralDynamicNoBatches_basic", "AvgPool2dDivisorOverrideModule_basic", - # Failure - onnx_lowering: onnx.Cast - "BucketizeTensorOutInt32RightModule_basic", - "ElementwiseToDtypeI64ToUI8Module_basic", - "QuantizedMLP_basic", - # Failure - onnx_lowering: onnx.Clip "NormalizeModule_basic", @@ -2054,12 +2051,20 @@ ONNX_XFAIL_SET = { # Failure - incorrect dtype "ReduceMaxAlongDimUnsignedInt_basic", + "ElementwiseToDtypeI64ToUI8Module_basic", # Failure - torch.aten.view lower "ViewSizeDimFollowedByExpandedOnesModule_basic", "ViewSizeDimLedAndFollowedByExpandedOnesModule_basic", "ViewSizeDimLedByExpandedOnesModule_basic", + # Failure - torch.aten.mm lower (mixed signedness of qtypes) + "QuantizedMLP_basic", + "QuantizedSingleLayer_basic", + + # Failure - torch.aten.squeeze lower + "BucketizeTensorOutInt32RightModule_basic", # unsupported by backend contract: tensor with unknown rank + # Failure - unknown "BucketizeTensorFloatModule_basic", "BucketizeTensorModule_basic", diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/__init__.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/__init__.py index c4d21ea08..1dea4cbe2 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/__init__.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/__init__.py @@ -12,6 +12,7 @@ from torch_mlir._version import torch_version_for_comparison, version COMMON_TORCH_MLIR_LOWERING_XFAILS = { "NativeGroupNormBackwardModule_basic", "QuantizedMLP_basic", + "QuantizedSingleLayer_basic", "ReduceMaxAlongDimUnsignedInt_basic", "ReduceMinAlongDimUnsignedInt_basic", "ElementwiseToDtypeI64ToUI8Module_basic", diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/quantized_models.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/quantized_models.py index e4a118700..262578b1e 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/quantized_models.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/quantized_models.py @@ -12,6 +12,28 @@ from torch_mlir_e2e_test.annotations import annotate_args, export # ============================================================================== +class QuantizedSingleLayer(nn.Module): + def __init__(self): + super().__init__() + torch.random.manual_seed(0) + self.layers = nn.Sequential( + nn.Linear(16, 8), + ) + self.quantize = torch.quantization.QuantStub() + self.dequantize = torch.quantization.DeQuantStub() + + @export + @export + @annotate_args([ + None, + ([1, 16], torch.float32, True), + ]) + def forward(self, x): + x = self.quantize(x) + x = self.layers(x) + x = self.dequantize(x) + return x + class QuantizedMLP(nn.Module): def __init__(self): @@ -53,6 +75,20 @@ def get_quantized_mlp(): torch.quantization.convert(model, inplace=True) return model +def get_quantized_single_layer(): + model = QuantizedSingleLayer() + model.eval() + model.qconfig = torch.quantization.default_qconfig + torch.quantization.prepare(model, inplace=True) + torch.manual_seed(0) + for _ in range(32): + model(get_mlp_input()) + torch.quantization.convert(model, inplace=True) + return model + +@register_test_case(module_factory=get_quantized_single_layer) +def QuantizedSingleLayer_basic(module, tu: TestUtils): + module.forward(get_mlp_input()) @register_test_case(module_factory=get_quantized_mlp) def QuantizedMLP_basic(module, tu: TestUtils):