mirror of https://github.com/llvm/torch-mlir
[ONNX] Preliminary Work Towards Supporting QuantizedMLP_basic onnx e2e test (#3089)
See the related issue here: [SHARK-Turbine#556](https://github.com/nod-ai/SHARK-Turbine/issues/556)

1. Adds uint8 casting to the onnx.Cast op.
2. Fixes an issue with onnx.DequantizeLinear when the scale comes with shape [1].
3. Adds support for unsigned types in an AtenItemOp folder.
4. Adds a simpler quantized model for easier debugging.
5. Adds a fusion pass to convert [quant -> dequant -> transpose -> mm] patterns to [transpose -> quant -> mm].
6. Moves some xfails that are still not passing, but for different reasons than onnx.Cast failures.
parent 3c33dbd987
commit 532d297c46

@@ -43,6 +43,8 @@ static int64_t onnxDtypeIntToTorchDtypeInt(int64_t dtypeIntOnnx) {
   switch (dtypeIntOnnx) {
   case 1:
     return 6; // float
+  case 2:
+    return 0; // uint8
   case 3:
     return 1; // int8
   case 6:
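For reference, the integer codes on both sides of this switch come from the ONNX `TensorProto.DataType` enum and torch-mlir's integer encoding of torch scalar types. A minimal Python sketch of the mapping this hunk extends (the dict and helper names are illustrative, not part of the patch):

```python
# Sketch: ONNX TensorProto.DataType codes -> torch scalar-type codes,
# mirroring onnxDtypeIntToTorchDtypeInt for the cases visible in this hunk.
ONNX_TO_TORCH_DTYPE = {
    1: 6,  # ONNX FLOAT -> torch float32
    2: 0,  # ONNX UINT8 -> torch uint8 (the case this patch adds)
    3: 1,  # ONNX INT8  -> torch int8
}

def onnx_dtype_to_torch_dtype(onnx_code: int) -> int:
    """Translate an ONNX dtype code, rejecting unmapped ones."""
    if onnx_code not in ONNX_TO_TORCH_DTYPE:
        raise ValueError(f"unsupported ONNX dtype code: {onnx_code}")
    return ONNX_TO_TORCH_DTYPE[onnx_code]
```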
@@ -1425,8 +1427,8 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
           if (!resultType.hasDtype())
             return rewriter.notifyMatchFailure(binder.op,
                                                "requires known result dtype");

-          if (scaleTy.getSizes().size() == 0) {
+          if (scaleTy.getSizes().size() == 0 ||
+              (scaleTy.getSizes().size() == 1 && scaleTy.getSizes()[0] == 1)) {
             Type qTy = operandTy.getDtype();

             if (qTy.isUnsignedInteger(8)) {
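The widened condition above treats a rank-1 scale of size 1 the same as a rank-0 scalar scale, since both describe per-tensor quantization. A small NumPy sketch of why the two shapes are interchangeable (values are arbitrary; illustrative only):

```python
import numpy as np

q = np.array([[10, 20], [30, 40]], dtype=np.uint8)  # quantized storage values
zero_point = 5

scale_scalar = np.float32(0.1)               # rank-0 scale, shape []
scale_vector = np.array([0.1], np.float32)   # rank-1 scale, shape [1]

# Per-tensor dequantization: (q - zero_point) * scale. Broadcasting makes
# a shape-[] and a shape-[1] scale behave identically here.
out_scalar = (q.astype(np.float32) - zero_point) * scale_scalar
out_vector = (q.astype(np.float32) - zero_point) * scale_vector
assert np.array_equal(out_scalar, out_vector)
```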
@@ -1455,7 +1457,8 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
             return success();
           }

-          return failure();
+          return rewriter.notifyMatchFailure(binder.op,
+                                             "unimplemented: non-scalar scale");
         });
   patterns.onOp("Div", 7,
                 [](OpBinder binder, ConversionPatternRewriter &rewriter) {
@@ -97,7 +97,7 @@ public:
     getZeroPoint(op.getSelf(), lhsZeroPoint);
     getZeroPoint(op.getMat2(), rhsZeroPoint);

-    if (static_cast<bool>(lhsZeroPoint) != static_cast<bool>(lhsZeroPoint)) {
+    if (static_cast<bool>(lhsZeroPoint) != static_cast<bool>(rhsZeroPoint)) {
      return rewriter.notifyMatchFailure(
          op, "unsupported: aten.mm with mixed quantization");
    }
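This one-character fix matters more than it looks: the old condition compared `lhsZeroPoint` against itself, which is never unequal, so the mixed-quantization guard could never fire. The same logic in Python, for illustration (the C++ `Value` handles are stand-ins here for any truthy/falsy objects):

```python
# Only one operand carries a zero point, i.e. mixed quantization.
lhs_zero_point, rhs_zero_point = None, 7

buggy = bool(lhs_zero_point) != bool(lhs_zero_point)  # lhs vs. itself: always False
fixed = bool(lhs_zero_point) != bool(rhs_zero_point)  # True when only one side is quantized

assert buggy is False and fixed is True
```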
@@ -3798,7 +3798,9 @@ OpFoldResult AtenItemOp::fold(FoldAdaptor adaptor) {
   if (matchPattern(getOperand(), m_Constant(&attr))) {
     auto splat = attr.getSplatValue<Attribute>();
     if (auto intAttr = dyn_cast<IntegerAttr>(splat)) {
-      return getI64IntegerAttr(getContext(), intAttr.getSInt());
+      return intAttr.getType().isUnsignedInteger()
+                 ? getI64IntegerAttr(getContext(), intAttr.getUInt())
+                 : getI64IntegerAttr(getContext(), intAttr.getSInt());
     }
     if (auto floatAttr = dyn_cast<FloatAttr>(splat)) {
       return getF64FloatAttr(getContext(), floatAttr.getValueAsDouble());
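The folder change above avoids sign-extending unsigned splats: reading a uint8 splat through `getSInt()` reinterprets bit patterns above 127 as negative. A NumPy illustration of the difference (illustrative only):

```python
import numpy as np

raw = np.array([200], dtype=np.uint8)  # 8-bit pattern 0xC8

unsigned_read = int(raw[0])               # zero-extended, like getUInt(): 200
signed_read = int(raw.view(np.int8)[0])   # sign-extended, like getSInt(): -56

assert unsigned_read == 200 and signed_read == -56
```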
@@ -54,6 +54,68 @@ public:
   }
 };

+template <typename SrcOp>
+class QuantizeTransposedOperands : public OpRewritePattern<SrcOp> {
+public:
+  using OpRewritePattern<SrcOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(SrcOp op,
+                                PatternRewriter &rewriter) const override {
+
+    llvm::SmallVector<Value> operands(op->getOperands());
+    unsigned numOperands = operands.size();
+    bool dequanted = false;
+    for (unsigned i = 0; i < numOperands; i++) {
+      if (auto trans = operands[i].getDefiningOp<AtenTransposeIntOp>()) {
+        auto transOperands = trans.getOperands();
+        Value dequantOperand;
+        if (auto dequant =
+                transOperands[0].getDefiningOp<AtenDequantizeSelfOp>()) {
+          dequantOperand = dequant.getOperand();
+          if (auto quant =
+                  dequantOperand
+                      .getDefiningOp<Aten_MakePerTensorQuantizedTensorOp>()) {
+            auto quantOperands = quant.getOperands();
+            auto qType = quantOperands[0]
+                             .getType()
+                             .cast<ValueTensorType>()
+                             .getOptionalDtype();
+            auto torchQType =
+                quant.getType().cast<ValueTensorType>().getOptionalDtype();
+            auto transQTy =
+                rewriter.getType<ValueTensorType>(trans.getResult()
+                                                      .getType()
+                                                      .cast<ValueTensorType>()
+                                                      .getOptionalSizes(),
+                                                  qType);
+            auto newQuantTy =
+                rewriter.getType<ValueTensorType>(trans.getResult()
+                                                      .getType()
+                                                      .cast<ValueTensorType>()
+                                                      .getOptionalSizes(),
+                                                  torchQType);
+            Value newTrans = rewriter.create<AtenTransposeIntOp>(
+                op.getLoc(), transQTy, quantOperands[0], transOperands[1],
+                transOperands[2]);
+            Value newQuant =
+                rewriter.create<Aten_MakePerTensorQuantizedTensorOp>(
+                    op.getLoc(), newQuantTy, newTrans, quantOperands[1],
+                    quantOperands[2]);
+            operands[i] = newQuant;
+            dequanted = true;
+          }
+        }
+      }
+    }
+    if (!dequanted) {
+      return rewriter.notifyMatchFailure(
+          op, "no dequantized transpose inputs found.");
+    }
+    rewriter.replaceOpWithNewOp<SrcOp>(op, op.getType(), operands);
+    return success();
+  }
+};
+
 template <typename SrcOp> class QuantizeBias : public OpRewritePattern<SrcOp> {
 public:
   using OpRewritePattern<SrcOp>::OpRewritePattern;
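The new `QuantizeTransposedOperands` pattern relies on per-tensor quantization being elementwise: permuting the quantized storage and then applying the affine dequantization map yields the same tensor as dequantizing first and transposing after. A NumPy check of that identity (shapes, scale, and zero point are arbitrary; illustrative only):

```python
import numpy as np

q = np.random.randint(0, 256, size=(3, 4)).astype(np.uint8)  # quantized storage
scale, zero_point = 0.05, 128

def dequant(t):
    # Per-tensor affine dequantization; purely elementwise, so layout-independent.
    return (t.astype(np.float32) - zero_point) * scale

# [dequant -> transpose] == [transpose -> dequant], which is what lets the
# pass hoist the transpose onto the quantized tensor and feed the matmul
# quantized operands directly.
assert np.array_equal(dequant(q).T, dequant(q.T))
```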
@@ -217,13 +279,15 @@ public:
   void runOnOperation() override {
     MLIRContext *context = &getContext();
     RewritePatternSet patterns(context);
-    patterns
-        .insert<RemoveUnused<AtenDequantizeSelfOp>,
-                RemoveUnused<AtenDequantizeTensorOp>,
-                RemoveUnused<AtenQuantizePerTensorOp>,
-                QuantizeOperands<AtenConvolutionOp>, QuantizeOperands<AtenMmOp>,
-                QuantizeAccumulator<AtenMmOp>, QuantizeBias<AtenConvolutionOp>>(
-            context);
+    patterns.insert<
+        RemoveUnused<AtenDequantizeSelfOp>,
+        RemoveUnused<AtenDequantizeTensorOp>,
+        RemoveUnused<AtenQuantizePerTensorOp>,
+        RemoveUnused<Aten_MakePerTensorQuantizedTensorOp>,
+        RemoveUnused<AtenTransposeIntOp>, QuantizeOperands<AtenConvolutionOp>,
+        QuantizeOperands<AtenMmOp>, QuantizeTransposedOperands<AtenMmOp>,
+        QuantizeAccumulator<AtenMmOp>, QuantizeBias<AtenConvolutionOp>>(
+        context);

     GreedyRewriteConfig config;
     if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns),
@@ -149,6 +149,7 @@ TORCHDYNAMO_XFAIL_SET = {
     'AtenFloatScalarModule_basic',
     'AtenIntBoolOpModule_basic',
     'QuantizedMLP_basic',
+    'QuantizedSingleLayer_basic',
     'ScalarImplicitFloatModule_basic',
     'ScalarImplicitIntModule_basic',
     # END tests failing due to: torch._dynamo.exc.Unsupported: data dependent operator: aten._local_scalar_dense.default
@@ -1412,6 +1413,7 @@ LTC_XFAIL_SET = {
     "NeFloatIntModule_basic",
     "NeIntModule_basic",
     "QuantizedMLP_basic",
+    "QuantizedSingleLayer_basic",
     "ScalarImplicitFloatModule_basic",
     "ScalarImplicitIntModule_basic",
     "SliceEndSleStartModule_basic",
@@ -1911,11 +1913,6 @@ ONNX_XFAIL_SET = {
     "AdaptiveAvgPool1dGeneralDynamicNoBatches_basic",
     "AvgPool2dDivisorOverrideModule_basic",

-    # Failure - onnx_lowering: onnx.Cast
-    "BucketizeTensorOutInt32RightModule_basic",
-    "ElementwiseToDtypeI64ToUI8Module_basic",
-    "QuantizedMLP_basic",
-
     # Failure - onnx_lowering: onnx.Clip
     "NormalizeModule_basic",

@@ -2054,12 +2051,20 @@ ONNX_XFAIL_SET = {

     # Failure - incorrect dtype
     "ReduceMaxAlongDimUnsignedInt_basic",
+    "ElementwiseToDtypeI64ToUI8Module_basic",

     # Failure - torch.aten.view lower
     "ViewSizeDimFollowedByExpandedOnesModule_basic",
     "ViewSizeDimLedAndFollowedByExpandedOnesModule_basic",
     "ViewSizeDimLedByExpandedOnesModule_basic",

+    # Failure - torch.aten.mm lower (mixed signedness of qtypes)
+    "QuantizedMLP_basic",
+    "QuantizedSingleLayer_basic",
+
+    # Failure - torch.aten.squeeze lower
+    "BucketizeTensorOutInt32RightModule_basic", # unsupported by backend contract: tensor with unknown rank
+
     # Failure - unknown
     "BucketizeTensorFloatModule_basic",
     "BucketizeTensorModule_basic",
@@ -12,6 +12,7 @@ from torch_mlir._version import torch_version_for_comparison, version
 COMMON_TORCH_MLIR_LOWERING_XFAILS = {
     "NativeGroupNormBackwardModule_basic",
     "QuantizedMLP_basic",
+    "QuantizedSingleLayer_basic",
     "ReduceMaxAlongDimUnsignedInt_basic",
     "ReduceMinAlongDimUnsignedInt_basic",
     "ElementwiseToDtypeI64ToUI8Module_basic",
@@ -12,6 +12,28 @@ from torch_mlir_e2e_test.annotations import annotate_args, export

 # ==============================================================================

+class QuantizedSingleLayer(nn.Module):
+    def __init__(self):
+        super().__init__()
+        torch.random.manual_seed(0)
+        self.layers = nn.Sequential(
+            nn.Linear(16, 8),
+        )
+        self.quantize = torch.quantization.QuantStub()
+        self.dequantize = torch.quantization.DeQuantStub()
+
+    @export
+    @annotate_args([
+        None,
+        ([1, 16], torch.float32, True),
+    ])
+    def forward(self, x):
+        x = self.quantize(x)
+        x = self.layers(x)
+        x = self.dequantize(x)
+        return x
+
+
 class QuantizedMLP(nn.Module):
     def __init__(self):
@@ -53,6 +75,20 @@ def get_quantized_mlp():
     torch.quantization.convert(model, inplace=True)
     return model

+def get_quantized_single_layer():
+    model = QuantizedSingleLayer()
+    model.eval()
+    model.qconfig = torch.quantization.default_qconfig
+    torch.quantization.prepare(model, inplace=True)
+    torch.manual_seed(0)
+    for _ in range(32):
+        model(get_mlp_input())
+    torch.quantization.convert(model, inplace=True)
+    return model
+
+@register_test_case(module_factory=get_quantized_single_layer)
+def QuantizedSingleLayer_basic(module, tu: TestUtils):
+    module.forward(get_mlp_input())
+
 @register_test_case(module_factory=get_quantized_mlp)
 def QuantizedMLP_basic(module, tu: TestUtils):
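When debugging, the same calibrate-and-convert flow can be run standalone to poke at the converted single-layer model outside the e2e harness; a hedged sketch (it mirrors the test code above, with a plain random input standing in for the suite's `get_mlp_input` helper):

```python
import torch
from torch import nn

# Standalone copy of the new single-layer model, eager-mode post-training
# quantization: prepare -> calibrate -> convert.
class SingleLayer(nn.Module):
    def __init__(self):
        super().__init__()
        torch.random.manual_seed(0)
        self.layers = nn.Sequential(nn.Linear(16, 8))
        self.quantize = torch.quantization.QuantStub()
        self.dequantize = torch.quantization.DeQuantStub()

    def forward(self, x):
        return self.dequantize(self.layers(self.quantize(x)))

model = SingleLayer().eval()
model.qconfig = torch.quantization.default_qconfig
torch.quantization.prepare(model, inplace=True)
for _ in range(32):           # calibration passes let the observers record ranges
    model(torch.rand(1, 16))
torch.quantization.convert(model, inplace=True)

print(model.layers[0].weight().dtype)  # torch.qint8: the Linear is now quantized
print(model(torch.rand(1, 16)))        # float output after the final dequantize
```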