diff --git a/include/torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h b/include/torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h
index 2df6f95c8..261b4df3b 100644
--- a/include/torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h
+++ b/include/torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h
@@ -251,7 +251,10 @@ public:
   OnnxCustomOpConversionPattern(MLIRContext *context, std::string domainPrefix,
                                 int64_t domainVersion)
       : OpConversionPattern(context), domainPrefix(std::move(domainPrefix)),
-        domainVersion(domainVersion) {}
+        domainVersion(domainVersion) {
+    // Onnx lowerings could produce other Onnx operations during the rewrite.
+    setHasBoundedRewriteRecursion();
+  }
 
   LogicalResult matchAndRewrite(Torch::OperatorOp op, OpAdaptor adaptor,
diff --git a/include/torch-mlir/Conversion/TorchOnnxToTorch/Utils.h b/include/torch-mlir/Conversion/TorchOnnxToTorch/Utils.h
index 058fee4da..afc14a95e 100644
--- a/include/torch-mlir/Conversion/TorchOnnxToTorch/Utils.h
+++ b/include/torch-mlir/Conversion/TorchOnnxToTorch/Utils.h
@@ -18,6 +18,8 @@ Value createConstantIntList(OpBinder binder,
                             ConversionPatternRewriter &rewriter,
                             SmallVector<int64_t> cstInput);
 
+Type getQTorchTypeFromTorchIntType(Type ty);
+
 } // namespace mlir::torch::onnx_c
 
 #endif // TORCHMLIR_CONVERSION_TORCHONNXTOTORCH_UTILS_H
diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp
index 06a9662c9..9550e982b 100644
--- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp
+++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp
@@ -690,7 +690,7 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
         return failure();
       });
   patterns.onOp(
-      "Conv", 11, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
+      "Conv", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
         std::string autoPad;
         if (binder.customOpNameStringAttr(autoPad, "auto_pad", "NOTSET"))
           return failure();
diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp
index d54c1e1b9..8227514b5 100644
--- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp
+++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp
@@ -8,6 +8,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h"
+#include "torch-mlir/Conversion/TorchOnnxToTorch/Utils.h"
 #include "torch-mlir/Dialect/Torch/Utils/Utils.h"
 #include "llvm/ADT/ArrayRef.h"
 #include "llvm/ADT/SmallVector.h"
@@ -99,6 +100,117 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
         return failure();
       });
+  patterns.onOp(
+      "QLinearConv", 1,
+      [](OpBinder binder, ConversionPatternRewriter &rewriter) {
+        Torch::ValueTensorType resultType;
+        llvm::SmallVector<Value> operands;
+        if ((binder.tensorOperands(operands, 8) &&
+             binder.tensorOperands(operands, 9)) ||
+            binder.tensorResultType(resultType))
+          return failure();
+        Value a = operands[0];
+        Value aScale = operands[1];
+        Value aZp = operands[2];
+        Value b = operands[3];
+        Value bScale = operands[4];
+        Value bZp = operands[5];
+        Value cScale = operands[6];
+        Value cZp = operands[7];
+        Value c = operands.size() == 9 ? operands[8] : nullptr;
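+
+        // Each scale / zero point must be a single-element tensor, i.e.
+        // per-tensor quantization; per-channel quantized inputs are rejected.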
+        auto check = [](Value v) {
+          auto vTy = v.getType().cast<Torch::ValueTensorType>();
+          return llvm::all_of(vTy.getSizes(), [](int64_t d) { return d == 1; });
+        };
+        if (!check(aScale) || !check(aZp) || !check(bScale) || !check(bZp) ||
+            !check(cScale) || !check(cZp))
+          return rewriter.notifyMatchFailure(
+              binder.op, "not supported for non per-tensor quantization");
+
+        auto extract = [&rewriter, &binder](Value v) {
+          auto vTy = v.getType().cast<Torch::ValueTensorType>();
+          Type extractTy = rewriter.getType<Torch::FloatType>();
+          if (isa<IntegerType>(vTy.getDtype()))
+            extractTy = rewriter.getType<Torch::IntType>();
+
+          return rewriter.create<Torch::AtenItemOp>(binder.getLoc(), extractTy,
+                                                    v);
+        };
+
+        aZp = extract(aZp);
+        bZp = extract(bZp);
+        cZp = extract(cZp);
+        aScale = extract(aScale);
+        bScale = extract(bScale);
+        cScale = extract(cScale);
+
+        auto make = [&rewriter, &binder](Value v, Value scale,
+                                         Value zp) -> Value {
+          auto ty = v.getType().cast<Torch::ValueTensorType>();
+          auto newTy = getQTorchTypeFromTorchIntType(ty);
+          return rewriter.create<Torch::Aten_MakePerTensorQuantizedTensorOp>(
+              binder.getLoc(), newTy, v, scale, zp);
+        };
+
+        a = make(a, aScale, aZp);
+        b = make(b, bScale, bZp);
+
+        auto cTy = rewriter.getType<Torch::ValueTensorType>(
+            resultType.getOptionalSizes(),
+            rewriter.getIntegerType(32, /*isSigned=*/true));
+
+        // TODO(suderman): insert convolution operator.
+        llvm::SmallVector<Value> newOperands = {a, b};
+        if (c)
+          newOperands.push_back(c);
+
+        cTy = rewriter.getType<Torch::ValueTensorType>(
+            resultType.getOptionalSizes(),
+            rewriter.getType<Torch::QInt32Type>());
+
+        llvm::SmallVector<NamedAttribute> newAttributes;
+        newAttributes.push_back(
+            rewriter.getNamedAttr("name", rewriter.getStringAttr("onnx.Conv")));
+        for (auto namedAttr : binder.op->getAttrDictionary()) {
+          if (namedAttr.getName().getValue().compare("name") == 0)
+            continue;
+          newAttributes.push_back(namedAttr);
+        }
+
+        c = rewriter
+                .create<Torch::OperatorOp>(binder.getLoc(), cTy, newOperands,
+                                           newAttributes)
+                .getResult(0);
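+
+        // The convolution accumulates in i32 with an effective scale of
+        // aScale * bScale and a zero point of zero; wrap the result as a
+        // quantized tensor, dequantize to f32, and requantize below to the
+        // requested output scale and zero point.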
+
+        Value outScale = rewriter.create<Torch::AtenMulFloatOp>(
+            binder.getLoc(), rewriter.getType<Torch::FloatType>(), aScale,
+            bScale);
+        Value outZp = rewriter.create<Torch::ConstantIntOp>(
+            binder.getLoc(), rewriter.getType<Torch::IntType>(),
+            rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0));
+        c = rewriter.create<Torch::Aten_MakePerTensorQuantizedTensorOp>(
+            binder.getLoc(), cTy, c, outScale, outZp);
+        cTy = rewriter.getType<Torch::ValueTensorType>(
+            resultType.getOptionalSizes(), rewriter.getF32Type());
+
+        c = rewriter.create<Torch::AtenDequantizeSelfOp>(binder.getLoc(), cTy,
+                                                         c);
+        cTy = dyn_cast<Torch::ValueTensorType>(
+            getQTorchTypeFromTorchIntType(resultType));
+        Value dtyVal = rewriter.create<Torch::ConstantIntOp>(
+            binder.getLoc(), rewriter.getType<Torch::IntType>(),
+            rewriter.getIntegerAttr(
+                rewriter.getIntegerType(64),
+                static_cast<int64_t>(
+                    Torch::getScalarTypeForType(cTy.getDtype()))));
+        c = rewriter.create<Torch::AtenQuantizePerTensorOp>(
+            binder.getLoc(), cTy, c, cScale, cZp, dtyVal);
+        rewriter.replaceOpWithNewOp<Torch::AtenIntReprOp>(binder.op, resultType,
+                                                          c);
+        return success();
+      });
   patterns.onOp(
       "QLinearMatMul", 1,
       [](OpBinder binder, ConversionPatternRewriter &rewriter) {
@@ -157,28 +269,10 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
           bScale = extract(bScale);
           cScale = extract(cScale);
 
-          auto getQTy =
-              [&rewriter](Torch::ValueTensorType ty) -> Torch::ValueTensorType {
-            auto dt = ty.getDtype();
-            Type newDt;
-            if (dt.isUnsignedInteger(8)) {
-              newDt = rewriter.getType<Torch::QUInt8Type>();
-            } else if (dt.isSignedInteger(8)) {
-              newDt = rewriter.getType<Torch::QInt8Type>();
-            } else if (dt.isSignedInteger(32)) {
-              newDt = rewriter.getType<Torch::QInt32Type>();
-            } else {
-              return nullptr;
-            }
-
-            return rewriter.getType<Torch::ValueTensorType>(ty.getOptionalSizes(),
-                                                            newDt);
-          };
-
-          auto make = [&rewriter, &binder, &getQTy](Value v, Value scale,
-                                                    Value zp) -> Value {
+          auto make = [&rewriter, &binder](Value v, Value scale,
+                                           Value zp) -> Value {
             auto ty = v.getType().cast<Torch::ValueTensorType>();
-            auto newTy = getQTy(ty);
+            auto newTy = getQTorchTypeFromTorchIntType(ty);
             return rewriter.create<Torch::Aten_MakePerTensorQuantizedTensorOp>(
                 binder.getLoc(), newTy, v, scale, zp);
           };
@@ -214,7 +308,8 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
 
           c = rewriter.create<Torch::AtenDequantizeSelfOp>(binder.getLoc(), cTy,
                                                            c);
-          cTy = getQTy(resultType);
+          cTy = dyn_cast<Torch::ValueTensorType>(
+              getQTorchTypeFromTorchIntType(resultType));
           Value dtyVal = rewriter.create<Torch::ConstantIntOp>(
               binder.getLoc(), rewriter.getType<Torch::IntType>(),
               rewriter.getIntegerAttr(
diff --git a/lib/Conversion/TorchOnnxToTorch/Utils.cpp b/lib/Conversion/TorchOnnxToTorch/Utils.cpp
index 8f5a2e67c..ef3da8b3b 100644
--- a/lib/Conversion/TorchOnnxToTorch/Utils.cpp
+++ b/lib/Conversion/TorchOnnxToTorch/Utils.cpp
@@ -8,6 +8,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "torch-mlir/Conversion/TorchOnnxToTorch/Utils.h"
+#include "torch-mlir/Dialect/Torch/IR/TorchTypes.h"
 
 using namespace mlir;
 using namespace mlir::torch;
@@ -26,3 +27,23 @@ Value mlir::torch::onnx_c::createConstantIntList(
       Torch::ListType::get(Torch::IntType::get(binder.op->getContext())),
       cstValue);
 }
+
+Type mlir::torch::onnx_c::getQTorchTypeFromTorchIntType(Type ty) {
+  Torch::ValueTensorType tty = dyn_cast<Torch::ValueTensorType>(ty);
+  if (!tty)
+    return nullptr;
+
+  auto ctx = ty.getContext();
+  Type dty = tty.getDtype();
+
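+  // Map the integer storage dtype to the matching torch quantized dtype:
+  // ui8 -> quint8, si8 -> qint8, si32 -> qint32.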
+  if (dty.isUnsignedInteger(8))
+    dty = Torch::QUInt8Type::get(ctx);
+  if (dty.isSignedInteger(8))
+    dty = Torch::QInt8Type::get(ctx);
+  if (dty.isSignedInteger(32))
+    dty = Torch::QInt32Type::get(ctx);
+
+  if (!dty)
+    return nullptr;
+  return Torch::ValueTensorType::get(ctx, tty.getOptionalSizes(), dty);
+}
diff --git a/lib/Conversion/TorchToLinalg/Linear.cpp b/lib/Conversion/TorchToLinalg/Linear.cpp
index 4523febb9..3557b27a2 100644
--- a/lib/Conversion/TorchToLinalg/Linear.cpp
+++ b/lib/Conversion/TorchToLinalg/Linear.cpp
@@ -653,7 +653,7 @@ public:
           op, "lhs and rhs of convolution must either be both int or fp");
     }
 
-    if (inputZp && weightZp) {
+    if (inputZp && weightZp && !isa<Torch::NoneType>(bias.getType())) {
      auto biasDTy = bias.getType().cast<RankedTensorType>().getElementType();
      if (!biasDTy.isInteger(32)) {
        return rewriter.notifyMatchFailure(
diff --git a/lib/Dialect/Torch/Transforms/FuseQuantizedOps.cpp b/lib/Dialect/Torch/Transforms/FuseQuantizedOps.cpp
index 15d5ec105..6bc8a8ba0 100644
--- a/lib/Dialect/Torch/Transforms/FuseQuantizedOps.cpp
+++ b/lib/Dialect/Torch/Transforms/FuseQuantizedOps.cpp
@@ -64,10 +64,6 @@ public:
     if (operands.size() < 3)
       return failure();
 
-    Value bias = operands[2];
-    if (bias.getDefiningOp())
-      return failure();
-
     Value lhsScale;
     if (auto qLhs =
             operands[0].getDefiningOp<Aten_MakePerTensorQuantizedTensorOp>())
@@ -82,11 +78,18 @@
       return failure();
 
     auto resultTy = cast<ValueTensorType>(op.getType());
-    auto biasTy = bias.getType().cast<ValueTensorType>();
-    auto biasETy = biasTy.getOptionalDtype();
-    if (!biasETy || !isa<mlir::FloatType>(biasETy))
+    if (!isa<mlir::FloatType>(resultTy.getDtype()))
       return failure();
 
+    Value bias = operands[2];
+    auto biasTy = bias.getType().dyn_cast<ValueTensorType>();
+
+    if (biasTy) {
+      auto biasETy = biasTy.getOptionalDtype();
+      if (!biasETy || !isa<mlir::FloatType>(biasETy))
+        return failure();
+    }
+
     Value biasScale = rewriter.create<AtenMulFloatOp>(
         op.getLoc(), lhsScale.getType(), lhsScale, rhsScale);
 
@@ -95,19 +98,21 @@
         rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0));
 
     auto qi32Ty = rewriter.getType<QInt32Type>();
-    auto newBiasTy =
-        rewriter.getType<ValueTensorType>(biasTy.getOptionalSizes(), qi32Ty);
-    Value dtype = getDtypeIntValueForType(rewriter, op.getLoc(), qi32Ty);
-    bias = rewriter.create<AtenQuantizePerTensorOp>(
-        op.getLoc(), newBiasTy, bias, biasScale, zero, dtype);
-    bias = rewriter.create<AtenIntReprOp>(
-        op.getLoc(),
-        rewriter.getType<ValueTensorType>(
-            biasTy.getOptionalSizes(),
-            rewriter.getIntegerType(32, IntegerType::Signed)),
-        bias);
-    operands[2] = bias;
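+    // Only fuse the bias when one is present: quantize it to qint32 with the
+    // combined scale (lhsScale * rhsScale) and a zero point of zero so it
+    // matches the i32 accumulator of the fused op.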
+    if (biasTy) {
+      auto newBiasTy =
+          rewriter.getType<ValueTensorType>(biasTy.getOptionalSizes(), qi32Ty);
+      Value dtype = getDtypeIntValueForType(rewriter, op.getLoc(), qi32Ty);
+      bias = rewriter.create<AtenQuantizePerTensorOp>(
+          op.getLoc(), newBiasTy, bias, biasScale, zero, dtype);
+      bias = rewriter.create<AtenIntReprOp>(
+          op.getLoc(),
+          rewriter.getType<ValueTensorType>(
+              biasTy.getOptionalSizes(),
+              rewriter.getIntegerType(32, IntegerType::Signed)),
+          bias);
+      operands[2] = bias;
+    }
 
     auto convTy = rewriter.getType<ValueTensorType>(
         resultTy.getOptionalSizes(),
diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir
index 9d947dce5..ae36661bd 100644
--- a/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir
+++ b/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir
@@ -47,6 +47,83 @@ func.func @test_quantizelinear_i32(%arg0: !torch.vtensor<[6],f32>, %arg1: !torch
 
 // -----
 
+// CHECK-LABEL: @test_qlinearconv_nobias
+func.func @test_qlinearconv_nobias(%arg0: !torch.vtensor<[1,1,7,7],ui8>, %arg1: !torch.vtensor<[],f32>, %arg2: !torch.vtensor<[],ui8>, %arg3: !torch.vtensor<[1,1,1,1],ui8>, %arg4: !torch.vtensor<[1],f32>, %arg5: !torch.vtensor<[1],ui8>, %arg6: !torch.vtensor<[],f32>, %arg7: !torch.vtensor<[],ui8>) -> !torch.vtensor<[1,1,7,7],ui8> attributes {torch.onnx_meta.ir_version = 5 : si64, torch.onnx_meta.opset_version = 10 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
+  %0 = torch.operator "onnx.QLinearConv"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7) : (!torch.vtensor<[1,1,7,7],ui8>, !torch.vtensor<[],f32>, !torch.vtensor<[],ui8>, !torch.vtensor<[1,1,1,1],ui8>, !torch.vtensor<[1],f32>, !torch.vtensor<[1],ui8>, !torch.vtensor<[],f32>, !torch.vtensor<[],ui8>) -> !torch.vtensor<[1,1,7,7],ui8>
+  // CHECK: %[[aZp:.+]] = torch.aten.item %arg2 : !torch.vtensor<[],ui8> -> !torch.int
+  // CHECK: %[[bZp:.+]] = torch.aten.item %arg5 : !torch.vtensor<[1],ui8> -> !torch.int
+  // CHECK: %[[cZp:.+]] = torch.aten.item %arg7 : !torch.vtensor<[],ui8> -> !torch.int
+  // CHECK: %[[aScale:.+]] = torch.aten.item %arg1 : !torch.vtensor<[],f32> -> !torch.float
+  // CHECK: %[[bScale:.+]] = torch.aten.item %arg4 : !torch.vtensor<[1],f32> -> !torch.float
+  // CHECK: %[[cScale:.+]] = torch.aten.item %arg6 : !torch.vtensor<[],f32> -> !torch.float
+  // CHECK: %[[A:.+]] = torch.aten._make_per_tensor_quantized_tensor %arg0, %[[aScale]], %[[aZp]] : !torch.vtensor<[1,1,7,7],ui8>, !torch.float, !torch.int -> !torch.vtensor<[1,1,7,7],!torch.quint8>
+  // CHECK: %[[B:.+]] = torch.aten._make_per_tensor_quantized_tensor %arg3, %[[bScale]], %[[bZp]] : !torch.vtensor<[1,1,1,1],ui8>, !torch.float, !torch.int -> !torch.vtensor<[1,1,1,1],!torch.quint8>
+  // CHECK: %[[INT0_0:.+]] = torch.constant.int 0
+  // CHECK: %[[INT0_1:.+]] = torch.constant.int 0
+  // CHECK: %[[INT1_0:.+]] = torch.constant.int 1
+  // CHECK: %[[INT1_1:.+]] = torch.constant.int 1
+  // CHECK: %[[INT1_2:.+]] = torch.constant.int 1
+  // CHECK: %[[INT1_3:.+]] = torch.constant.int 1
+  // CHECK: %[[INT0_2:.+]] = torch.constant.int 0
+  // CHECK: %[[PAD:.+]] = torch.prim.ListConstruct %[[INT0_0]], %[[INT0_1]]
+  // CHECK: %[[KERNEL:.+]] = torch.prim.ListConstruct %[[INT1_0]], %[[INT1_1]]
+  // CHECK: %[[DILATION:.+]] = torch.prim.ListConstruct %[[INT1_2]], %[[INT1_3]]
+  // CHECK: %[[STRIDE:.+]] = torch.prim.ListConstruct %[[INT0_2]], %[[INT0_2]]
+  // CHECK: %[[FALSE:.+]] = torch.constant.bool false
+  // CHECK: %[[NONE:.+]] = torch.constant.none
+  // CHECK: %[[INT1_5:.+]] = torch.constant.int 1
+  // CHECK: %[[CONV:.+]] = torch.aten.convolution %[[A]], %[[B]], %[[NONE]], %[[DILATION]], %[[PAD]], %[[KERNEL]], %[[FALSE]], %[[STRIDE]], %[[INT1_5]] : !torch.vtensor<[1,1,7,7],!torch.quint8>, !torch.vtensor<[1,1,1,1],!torch.quint8>, !torch.none, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int -> !torch.vtensor<[1,1,7,7],!torch.qint32>
+  // CHECK: %[[convScale:.+]] = torch.aten.mul.float %[[aScale]], %[[bScale]] : !torch.float, !torch.float -> !torch.float
+  // CHECK: %[[INT0_6:.+]] = torch.constant.int 0
+  // CHECK: %[[C:.+]] = torch.aten._make_per_tensor_quantized_tensor %[[CONV]], %[[convScale]], %[[INT0_6]] : !torch.vtensor<[1,1,7,7],!torch.qint32>, !torch.float, !torch.int -> !torch.vtensor<[1,1,7,7],!torch.qint32>
+  // CHECK: %[[DEQ:.+]] = torch.aten.dequantize.self %[[C]] : !torch.vtensor<[1,1,7,7],!torch.qint32> -> !torch.vtensor<[1,1,7,7],f32>
+  // CHECK: %[[INT13:.+]] = torch.constant.int 13
+  // CHECK: %[[QUANT:.+]] = torch.aten.quantize_per_tensor %[[DEQ]], %[[cScale]], %[[cZp]], %[[INT13]] : !torch.vtensor<[1,1,7,7],f32>, !torch.float, !torch.int, !torch.int -> !torch.vtensor<[1,1,7,7],!torch.quint8>
+  // CHECK: %[[INT:.+]] = torch.aten.int_repr %[[QUANT]] : !torch.vtensor<[1,1,7,7],!torch.quint8> -> !torch.vtensor<[1,1,7,7],ui8>
+  // CHECK: return %[[INT]] : !torch.vtensor<[1,1,7,7],ui8>
+  return %0 : !torch.vtensor<[1,1,7,7],ui8>
+}
+
+// -----
+
+// CHECK-LABEL: @test_qlinearconv_bias
+func.func @test_qlinearconv_bias(%arg0: !torch.vtensor<[1,1,7,7],ui8>, %arg1: !torch.vtensor<[],f32>, %arg2: !torch.vtensor<[],ui8>, %arg3: !torch.vtensor<[1,1,1,1],ui8>, %arg4: !torch.vtensor<[1],f32>, %arg5: !torch.vtensor<[1],ui8>, %arg6: !torch.vtensor<[],f32>, %arg7: !torch.vtensor<[],ui8>, %arg8 : !torch.vtensor<[7],si32>) -> !torch.vtensor<[1,1,7,7],ui8> attributes {torch.onnx_meta.ir_version = 5 : si64, torch.onnx_meta.opset_version = 10 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
+  %0 = torch.operator "onnx.QLinearConv"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8) : (!torch.vtensor<[1,1,7,7],ui8>, !torch.vtensor<[],f32>, !torch.vtensor<[],ui8>, !torch.vtensor<[1,1,1,1],ui8>, !torch.vtensor<[1],f32>, !torch.vtensor<[1],ui8>, !torch.vtensor<[],f32>, !torch.vtensor<[],ui8>, !torch.vtensor<[7],si32>) -> !torch.vtensor<[1,1,7,7],ui8>
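+  // Unlike the nobias case above, the i32 bias (%arg8) is forwarded directly
+  // as the bias operand of torch.aten.convolution below.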
+  // CHECK: %[[aZp:.+]] = torch.aten.item %arg2 : !torch.vtensor<[],ui8> -> !torch.int
+  // CHECK: %[[bZp:.+]] = torch.aten.item %arg5 : !torch.vtensor<[1],ui8> -> !torch.int
+  // CHECK: %[[cZp:.+]] = torch.aten.item %arg7 : !torch.vtensor<[],ui8> -> !torch.int
+  // CHECK: %[[aScale:.+]] = torch.aten.item %arg1 : !torch.vtensor<[],f32> -> !torch.float
+  // CHECK: %[[bScale:.+]] = torch.aten.item %arg4 : !torch.vtensor<[1],f32> -> !torch.float
+  // CHECK: %[[cScale:.+]] = torch.aten.item %arg6 : !torch.vtensor<[],f32> -> !torch.float
+  // CHECK: %[[A:.+]] = torch.aten._make_per_tensor_quantized_tensor %arg0, %[[aScale]], %[[aZp]] : !torch.vtensor<[1,1,7,7],ui8>, !torch.float, !torch.int -> !torch.vtensor<[1,1,7,7],!torch.quint8>
+  // CHECK: %[[B:.+]] = torch.aten._make_per_tensor_quantized_tensor %arg3, %[[bScale]], %[[bZp]] : !torch.vtensor<[1,1,1,1],ui8>, !torch.float, !torch.int -> !torch.vtensor<[1,1,1,1],!torch.quint8>
+  // CHECK: %[[INT0_0:.+]] = torch.constant.int 0
+  // CHECK: %[[INT0_1:.+]] = torch.constant.int 0
+  // CHECK: %[[INT1_0:.+]] = torch.constant.int 1
+  // CHECK: %[[INT1_1:.+]] = torch.constant.int 1
+  // CHECK: %[[INT1_2:.+]] = torch.constant.int 1
+  // CHECK: %[[INT1_3:.+]] = torch.constant.int 1
+  // CHECK: %[[INT0_2:.+]] = torch.constant.int 0
+  // CHECK: %[[PAD:.+]] = torch.prim.ListConstruct %[[INT0_0]], %[[INT0_1]]
+  // CHECK: %[[KERNEL:.+]] = torch.prim.ListConstruct %[[INT1_0]], %[[INT1_1]]
+  // CHECK: %[[DILATION:.+]] = torch.prim.ListConstruct %[[INT1_2]], %[[INT1_3]]
+  // CHECK: %[[STRIDE:.+]] = torch.prim.ListConstruct %[[INT0_2]], %[[INT0_2]]
+  // CHECK: %[[FALSE:.+]] = torch.constant.bool false
+  // CHECK: %[[INT1_5:.+]] = torch.constant.int 1
+  // CHECK: %[[CONV:.+]] = torch.aten.convolution %[[A]], %[[B]], %arg8, %[[DILATION]], %[[PAD]], %[[KERNEL]], %[[FALSE]], %[[STRIDE]], %[[INT1_5]] : !torch.vtensor<[1,1,7,7],!torch.quint8>, !torch.vtensor<[1,1,1,1],!torch.quint8>, !torch.vtensor<[7],si32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int -> !torch.vtensor<[1,1,7,7],!torch.qint32>
+  // CHECK: %[[convScale:.+]] = torch.aten.mul.float %[[aScale]], %[[bScale]] : !torch.float, !torch.float -> !torch.float
+  // CHECK: %[[INT0_6:.+]] = torch.constant.int 0
+  // CHECK: %[[C:.+]] = torch.aten._make_per_tensor_quantized_tensor %[[CONV]], %[[convScale]], %[[INT0_6]] : !torch.vtensor<[1,1,7,7],!torch.qint32>, !torch.float, !torch.int -> !torch.vtensor<[1,1,7,7],!torch.qint32>
+  // CHECK: %[[DEQ:.+]] = torch.aten.dequantize.self %[[C]] : !torch.vtensor<[1,1,7,7],!torch.qint32> -> !torch.vtensor<[1,1,7,7],f32>
+  // CHECK: %[[INT13:.+]] = torch.constant.int 13
+  // CHECK: %[[QUANT:.+]] = torch.aten.quantize_per_tensor %[[DEQ]], %[[cScale]], %[[cZp]], %[[INT13]] : !torch.vtensor<[1,1,7,7],f32>, !torch.float, !torch.int, !torch.int -> !torch.vtensor<[1,1,7,7],!torch.quint8>
+  // CHECK: %[[INT:.+]] = torch.aten.int_repr %[[QUANT]] : !torch.vtensor<[1,1,7,7],!torch.quint8> -> !torch.vtensor<[1,1,7,7],ui8>
+  // CHECK: return %[[INT]] : !torch.vtensor<[1,1,7,7],ui8>
+  return %0 : !torch.vtensor<[1,1,7,7],ui8>
+}
+
+// -----
+
 // CHECK-LABEL: @test_qlinearmatmul_2D
 func.func @test_qlinearmatmul_2D(%arg0: !torch.vtensor<[2,4],ui8>, %arg1: !torch.vtensor<[1],f32>, %arg2: !torch.vtensor<[1],ui8>, %arg3: !torch.vtensor<[4,3],ui8>, %arg4: !torch.vtensor<[1],f32>, %arg5: !torch.vtensor<[1],ui8>, %arg6: !torch.vtensor<[1],f32>, %arg7: !torch.vtensor<[1],ui8>) -> !torch.vtensor<[2,3],ui8> attributes {torch.onnx_meta.ir_version = 5 : si64, torch.onnx_meta.opset_version = 10 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
   %0 = torch.operator "onnx.QLinearMatMul"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7) : (!torch.vtensor<[2,4],ui8>, !torch.vtensor<[1],f32>, !torch.vtensor<[1],ui8>, !torch.vtensor<[4,3],ui8>, !torch.vtensor<[1],f32>, !torch.vtensor<[1],ui8>, !torch.vtensor<[1],f32>, !torch.vtensor<[1],ui8>) -> !torch.vtensor<[2,3],ui8>
diff --git a/test/Dialect/Torch/fuse-quantized-ops.mlir b/test/Dialect/Torch/fuse-quantized-ops.mlir
index 1aaeb9ce1..f98cb842f 100644
--- a/test/Dialect/Torch/fuse-quantized-ops.mlir
+++ b/test/Dialect/Torch/fuse-quantized-ops.mlir
@@ -28,8 +28,8 @@ func.func @mm(%arg0: !torch.vtensor<[4, 4],si8>, %arg1: !torch.vtensor<[4, 4],si
 
 // -----
 
-// CHECK-LABEL: @convolution
-func.func @convolution(%arg0: !torch.vtensor<[1,3,8,8],si8>, %arg1: !torch.vtensor<[3,3,2,2],si8>, %arg2 : !torch.vtensor<[3], f32>) -> !torch.vtensor<[1,3,7,7],f32> {
+// CHECK-LABEL: @convolution_bias
+func.func @convolution_bias(%arg0: !torch.vtensor<[1,3,8,8],si8>, %arg1: !torch.vtensor<[3,3,2,2],si8>, %arg2 : !torch.vtensor<[3], f32>) -> !torch.vtensor<[1,3,7,7],f32> {
   %scale = torch.constant.float 0.5
   %false = torch.constant.bool false
   %zero = torch.constant.int 0
@@ -60,3 +60,38 @@ func.func @convolution(%arg0: !torch.vtensor<[1,3,8,8],si8>, %arg1: !torch.vtens
   // CHECK: %[[FOUT:.+]] = torch.aten.dequantize.tensor %[[QOUT]] : !torch.vtensor<[1,3,7,7],!torch.qint32> -> !torch.vtensor<[1,3,7,7],f32>
   return %16 : !torch.vtensor<[1,3,7,7],f32>
 }
+
+
+// -----
+
+// CHECK-LABEL: @convolution_nobias
+func.func @convolution_nobias(%arg0: !torch.vtensor<[1,3,8,8],si8>, %arg1: !torch.vtensor<[3,3,2,2],si8>) -> !torch.vtensor<[1,3,7,7],f32> {
+  %scale = torch.constant.float 0.5
+  %false = torch.constant.bool false
+  %zero = torch.constant.int 0
+  %one = torch.constant.int 1
+  %zp = torch.constant.int -128
+  %none = torch.constant.none
+  %6 = torch.aten._make_per_tensor_quantized_tensor %arg0, %scale, %one : !torch.vtensor<[1,3,8,8],si8>, !torch.float, !torch.int -> !torch.vtensor<[1,3,8,8],!torch.qint8>
+  %7 = torch.aten.dequantize.tensor %6 : !torch.vtensor<[1,3,8,8],!torch.qint8> -> !torch.vtensor<[1,3,8,8],f32>
+  %12 = torch.aten._make_per_tensor_quantized_tensor %arg1, %scale, %zero : !torch.vtensor<[3,3,2,2],si8>, !torch.float, !torch.int -> !torch.vtensor<[3,3,2,2],!torch.qint8>
+  %13 = torch.aten.dequantize.tensor %12 : !torch.vtensor<[3,3,2,2],!torch.qint8> -> !torch.vtensor<[3,3,2,2],f32>
+  %14 = torch.prim.ListConstruct %one, %one : (!torch.int, !torch.int) -> !torch.list<int>
+  %15 = torch.prim.ListConstruct %zero, %zero : (!torch.int, !torch.int) -> !torch.list<int>
+  %16 = torch.aten.convolution %7, %13, %none, %14, %15, %14, %false, %15, %one : !torch.vtensor<[1,3,8,8],f32>, !torch.vtensor<[3,3,2,2],f32>, !torch.none, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int -> !torch.vtensor<[1,3,7,7],f32>
+
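+  // With no bias operand there is nothing to quantize, so the fused
+  // convolution below keeps the !torch.none bias.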
+  // CHECK-DAG: %[[SCALEO:.+]] = torch.constant.float 2.500000e-01
+  // CHECK-DAG: %[[HALF:.+]] = torch.constant.float 5.000000e-01
+  // CHECK-DAG: %[[FALSE:.+]] = torch.constant.bool false
+  // CHECK-DAG: %[[ZERO:.+]] = torch.constant.int 0
+  // CHECK-DAG: %[[ONE:.+]] = torch.constant.int 1
+  // CHECK-DAG: %[[NONE:.+]] = torch.constant.none
+  // CHECK-DAG: %[[QLHS:.+]] = torch.aten._make_per_tensor_quantized_tensor %arg0, %[[HALF]], %[[ONE]] : !torch.vtensor<[1,3,8,8],si8>, !torch.float, !torch.int -> !torch.vtensor<[1,3,8,8],!torch.qint8>
+  // CHECK-DAG: %[[QRHS:.+]] = torch.aten._make_per_tensor_quantized_tensor %arg1, %[[HALF]], %[[ZERO]] : !torch.vtensor<[3,3,2,2],si8>, !torch.float, !torch.int -> !torch.vtensor<[3,3,2,2],!torch.qint8>
+  // CHECK-DAG: %[[ONES:.+]] = torch.prim.ListConstruct %[[ONE]], %[[ONE]] : (!torch.int, !torch.int) -> !torch.list<int>
+  // CHECK-DAG: %[[ZEROS:.+]] = torch.prim.ListConstruct %[[ZERO]], %[[ZERO]] : (!torch.int, !torch.int) -> !torch.list<int>
+  // CHECK-DAG: %[[CONV:.+]] = torch.aten.convolution %[[QLHS]], %[[QRHS]], %[[NONE]], %[[ONES]], %[[ZEROS]], %[[ONES]], %[[FALSE]], %[[ZEROS]], %[[ONE]] : !torch.vtensor<[1,3,8,8],!torch.qint8>, !torch.vtensor<[3,3,2,2],!torch.qint8>, !torch.none, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int -> !torch.vtensor<[1,3,7,7],si32>
+  // CHECK-DAG: %[[QOUT:.+]] = torch.aten._make_per_tensor_quantized_tensor %[[CONV]], %[[SCALEO]], %[[ZERO]] : !torch.vtensor<[1,3,7,7],si32>, !torch.float, !torch.int -> !torch.vtensor<[1,3,7,7],!torch.qint32>
+  // CHECK: %[[FOUT:.+]] = torch.aten.dequantize.tensor %[[QOUT]] : !torch.vtensor<[1,3,7,7],!torch.qint32> -> !torch.vtensor<[1,3,7,7],f32>
+  return %16 : !torch.vtensor<[1,3,7,7],f32>
+}