mirror of https://github.com/llvm/torch-mlir
[onnx] Convert `onnx.QLinearConv` to `torch` (#2851)
Leaning on the QDQ (quantize-dequantize) functionality in torch, we can support the `QLinearConv` operation by piggybacking on `torch.aten.convolution`. This requires a few supporting changes, such as allowing the `onnx` rewriter to run recursively: `QLinearConv` first decomposes to an `onnx.Conv`, which is then lowered to `torch`.
parent cb52c4b3cc
commit e3faef5224
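For context, the rescale performed by the new QLinearConv pattern below (outScale = aScale * bScale with a zero point of 0 on the i32 accumulator) follows from standard per-tensor quantization algebra, which the patch leaves implicit: writing each tensor as $x \approx s_x (q_x - z_x)$,

    $\operatorname{conv}(A, B) \approx s_A s_B \cdot \operatorname{conv}(q_A - z_A,\; q_B - z_B)$

so the integer accumulator carries scale $s_A s_B$ and zero point 0, and the final quantize_per_tensor against (cScale, cZp) produces the declared output type.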
@@ -251,7 +251,10 @@ public:
   OnnxCustomOpConversionPattern(MLIRContext *context, std::string domainPrefix,
                                 int64_t domainVersion)
       : OpConversionPattern(context), domainPrefix(std::move(domainPrefix)),
-        domainVersion(domainVersion) {}
+        domainVersion(domainVersion) {
+    // Onnx lowerings could produce other Onnx operations during the rewrite.
+    setHasBoundedRewriteRecursion();
+  }
 
   LogicalResult
   matchAndRewrite(Torch::OperatorOp op, OpAdaptor adaptor,
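A note on the MLIR hook used in this hunk: a pattern that creates ops it can itself match must call setHasBoundedRewriteRecursion(), otherwise the greedy rewrite driver treats the re-firing as unbounded recursion and fails. A minimal sketch of the mechanism, assuming a hypothetical CountdownOp with an integer count attribute (not from this patch):

    #include "mlir/IR/PatternMatch.h"

    using namespace mlir;

    // Each application replaces CountdownOp(n) with CountdownOp(n - 1), an op
    // this same pattern matches again. The recursion is bounded because the
    // count strictly decreases and the pattern bails out at zero.
    struct CountdownLowering : public OpRewritePattern<CountdownOp> {
      CountdownLowering(MLIRContext *ctx) : OpRewritePattern(ctx) {
        setHasBoundedRewriteRecursion();
      }
      LogicalResult matchAndRewrite(CountdownOp op,
                                    PatternRewriter &rewriter) const override {
        if (op.getCount() == 0)
          return failure();
        rewriter.replaceOpWithNewOp<CountdownOp>(op, op.getType(),
                                                 op.getCount() - 1);
        return success();
      }
    };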
@@ -18,6 +18,8 @@ Value createConstantIntList(OpBinder binder,
                             ConversionPatternRewriter &rewriter,
                             SmallVector<int64_t> cstInput);
 
+Type getQTorchTypeFromTorchIntType(Type ty);
+
 } // namespace mlir::torch::onnx_c
 
 #endif // TORCHMLIR_CONVERSION_TORCHONNXTOTORCH_UTILS_H
@@ -690,7 +690,7 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
         return failure();
       });
   patterns.onOp(
-      "Conv", 11, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
+      "Conv", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
        std::string autoPad;
        if (binder.customOpNameStringAttr(autoPad, "auto_pad", "NOTSET"))
          return failure();
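The integer passed to onOp is the minimum opset version a lowering accepts, so dropping Conv from 11 to 1 lets this same handler fire on the onnx.Conv ops synthesized by the QLinearConv rewrite below, which carry no opset-11 provenance. That reading of the registration contract is inferred from its use in this file; a hedged sketch, with the parameter name being an assumption:

    // Assumed semantics: the pattern applies to any "onnx.Conv" whose
    // recorded opset version is at least the given floor.
    patterns.onOp("Conv", /*sinceVersion=*/1, convHandler);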
@@ -8,6 +8,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h"
+#include "torch-mlir/Conversion/TorchOnnxToTorch/Utils.h"
 #include "torch-mlir/Dialect/Torch/Utils/Utils.h"
 #include "llvm/ADT/ArrayRef.h"
 #include "llvm/ADT/SmallVector.h"
@@ -99,6 +100,117 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
 
         return failure();
       });
+  patterns.onOp(
+      "QLinearConv", 1,
+      [](OpBinder binder, ConversionPatternRewriter &rewriter) {
+        Torch::ValueTensorType resultType;
+        llvm::SmallVector<Value> operands;
+        if ((binder.tensorOperands(operands, 8) &&
+             binder.tensorOperands(operands, 9)) ||
+            binder.tensorResultType(resultType))
+          return failure();
+        Value a = operands[0];
+        Value aScale = operands[1];
+        Value aZp = operands[2];
+        Value b = operands[3];
+        Value bScale = operands[4];
+        Value bZp = operands[5];
+        Value cScale = operands[6];
+        Value cZp = operands[7];
+        Value c = operands.size() == 9 ? operands[8] : nullptr;
+
+        auto check = [](Value v) {
+          auto vTy = v.getType().cast<Torch::ValueTensorType>();
+          return llvm::all_of(vTy.getSizes(),
+                              [](int64_t d) { return d == 1; });
+        };
+        if (!check(aScale) || !check(aZp) || !check(bScale) || !check(bZp) ||
+            !check(cScale) || !check(cZp))
+          return rewriter.notifyMatchFailure(
+              binder.op, "not supported for non per-tensor quantization");
+
+        auto extract = [&rewriter, &binder](Value v) {
+          auto vTy = v.getType().cast<Torch::ValueTensorType>();
+          Type extractTy = rewriter.getType<Torch::FloatType>();
+          if (isa<IntegerType>(vTy.getDtype()))
+            extractTy = rewriter.getType<Torch::IntType>();
+
+          return rewriter.create<Torch::AtenItemOp>(binder.getLoc(), extractTy,
+                                                    v);
+        };
+
+        aZp = extract(aZp);
+        bZp = extract(bZp);
+        cZp = extract(cZp);
+        aScale = extract(aScale);
+        bScale = extract(bScale);
+        cScale = extract(cScale);
+
+        auto make = [&rewriter, &binder](Value v, Value scale,
+                                         Value zp) -> Value {
+          auto ty = v.getType().cast<Torch::ValueTensorType>();
+          auto newTy = getQTorchTypeFromTorchIntType(ty);
+          return rewriter.create<Torch::Aten_MakePerTensorQuantizedTensorOp>(
+              binder.getLoc(), newTy, v, scale, zp);
+        };
+
+        a = make(a, aScale, aZp);
+        b = make(b, bScale, bZp);
+
+        auto cTy = rewriter.getType<Torch::ValueTensorType>(
+            resultType.getOptionalSizes(),
+            rewriter.getIntegerType(32, /*isSigned=*/true));
+
+        // TODO(suderman): insert convolution operator.
+        llvm::SmallVector<Value> newOperands = {a, b};
+        if (c)
+          newOperands.push_back(c);
+
+        cTy = rewriter.getType<Torch::ValueTensorType>(
+            resultType.getOptionalSizes(),
+            rewriter.getType<Torch::QInt32Type>());
+
+        llvm::SmallVector<NamedAttribute> newAttributes;
+        newAttributes.push_back(
+            rewriter.getNamedAttr("name", rewriter.getStringAttr("onnx.Conv")));
+        for (auto namedAttr : binder.op->getAttrDictionary()) {
+          if (namedAttr.getName().getValue().compare("name") == 0)
+            continue;
+          newAttributes.push_back(namedAttr);
+        }
+
+        c = rewriter
+                .create<Torch::OperatorOp>(binder.getLoc(), cTy, newOperands,
+                                           newAttributes)
+                .getResult(0);
+
+        Value outScale = rewriter.create<Torch::AtenMulFloatOp>(
+            binder.getLoc(), rewriter.getType<Torch::FloatType>(), aScale,
+            bScale);
+        Value outZp = rewriter.create<Torch::ConstantIntOp>(
+            binder.getLoc(), rewriter.getType<Torch::IntType>(),
+            rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0));
+        c = rewriter.create<Torch::Aten_MakePerTensorQuantizedTensorOp>(
+            binder.getLoc(), cTy, c, outScale, outZp);
+        cTy = rewriter.getType<Torch::ValueTensorType>(
+            resultType.getOptionalSizes(), rewriter.getF32Type());
+
+        c = rewriter.create<Torch::AtenDequantizeSelfOp>(binder.getLoc(), cTy,
+                                                         c);
+        cTy = dyn_cast<Torch::ValueTensorType>(
+            getQTorchTypeFromTorchIntType(resultType));
+        Value dtyVal = rewriter.create<Torch::ConstantIntOp>(
+            binder.getLoc(), rewriter.getType<Torch::IntType>(),
+            rewriter.getIntegerAttr(
+                rewriter.getIntegerType(64),
+                static_cast<int64_t>(
+                    Torch::getScalarTypeForType(cTy.getDtype()))));
+        c = rewriter.create<Torch::AtenQuantizePerTensorOp>(
+            binder.getLoc(), cTy, c, cScale, cZp, dtyVal);
+        rewriter.replaceOpWithNewOp<Torch::AtenIntReprOp>(binder.op, resultType,
+                                                          c);
+        return success();
+      });
   patterns.onOp(
       "QLinearMatMul", 1,
       [](OpBinder binder, ConversionPatternRewriter &rewriter) {
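One non-obvious constant in the pattern above: the dtype operand handed to quantize_per_tensor is the numeric value of PyTorch's c10::ScalarType for the target quantized type (QInt8 = 12, QUInt8 = 13, QInt32 = 14), which is why the tests below check for torch.constant.int 13. A hedged fragment reusing the rewriter from the pattern:

    // For a quint8 result type, getScalarTypeForType yields ScalarType::QUInt8,
    // so the emitted dtype constant is 13.
    int64_t dtypeInt = static_cast<int64_t>(
        Torch::getScalarTypeForType(rewriter.getType<Torch::QUInt8Type>()));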
@@ -157,28 +269,10 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
         bScale = extract(bScale);
         cScale = extract(cScale);
 
-        auto getQTy =
-            [&rewriter](Torch::ValueTensorType ty) -> Torch::ValueTensorType {
-          auto dt = ty.getDtype();
-          Type newDt;
-          if (dt.isUnsignedInteger(8)) {
-            newDt = rewriter.getType<Torch::QUInt8Type>();
-          } else if (dt.isSignedInteger(8)) {
-            newDt = rewriter.getType<Torch::QInt8Type>();
-          } else if (dt.isSignedInteger(32)) {
-            newDt = rewriter.getType<Torch::QInt32Type>();
-          } else {
-            return nullptr;
-          }
-
-          return rewriter.getType<Torch::ValueTensorType>(ty.getOptionalSizes(),
-                                                          newDt);
-        };
-
-        auto make = [&rewriter, &binder, &getQTy](Value v, Value scale,
+        auto make = [&rewriter, &binder](Value v, Value scale,
                                          Value zp) -> Value {
           auto ty = v.getType().cast<Torch::ValueTensorType>();
-          auto newTy = getQTy(ty);
+          auto newTy = getQTorchTypeFromTorchIntType(ty);
           return rewriter.create<Torch::Aten_MakePerTensorQuantizedTensorOp>(
               binder.getLoc(), newTy, v, scale, zp);
         };
@@ -214,7 +308,8 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
 
         c = rewriter.create<Torch::AtenDequantizeSelfOp>(binder.getLoc(), cTy,
                                                          c);
-        cTy = getQTy(resultType);
+        cTy = dyn_cast<Torch::ValueTensorType>(
+            getQTorchTypeFromTorchIntType(resultType));
         Value dtyVal = rewriter.create<Torch::ConstantIntOp>(
             binder.getLoc(), rewriter.getType<Torch::IntType>(),
             rewriter.getIntegerAttr(
@@ -8,6 +8,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "torch-mlir/Conversion/TorchOnnxToTorch/Utils.h"
+#include "torch-mlir/Dialect/Torch/IR/TorchTypes.h"
 
 using namespace mlir;
 using namespace mlir::torch;
@@ -26,3 +27,23 @@ Value mlir::torch::onnx_c::createConstantIntList(
       Torch::ListType::get(Torch::IntType::get(binder.op->getContext())),
       cstValue);
 }
+
+Type mlir::torch::onnx_c::getQTorchTypeFromTorchIntType(Type ty) {
+  Torch::ValueTensorType tty = dyn_cast<Torch::ValueTensorType>(ty);
+  if (!tty)
+    return nullptr;
+
+  auto ctx = ty.getContext();
+  Type dty = tty.getDtype();
+
+  if (dty.isUnsignedInteger(8))
+    dty = Torch::QUInt8Type::get(ctx);
+  if (dty.isSignedInteger(8))
+    dty = Torch::QInt8Type::get(ctx);
+  if (dty.isSignedInteger(32))
+    dty = Torch::QInt32Type::get(ctx);
+
+  if (!dty)
+    return nullptr;
+  return Torch::ValueTensorType::get(ctx, tty.getOptionalSizes(), dty);
+}
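A usage sketch for the new helper (illustrative, not part of the patch): it swaps a vtensor's integer storage dtype for the matching torch quantized dtype while preserving the shape, mapping ui8 to !torch.quint8, si8 to !torch.qint8, and si32 to !torch.qint32; a non-tensor input comes back null, and any other dtype passes through unchanged.

    // Maps !torch.vtensor<[1,1,7,7],ui8> to !torch.vtensor<[1,1,7,7],!torch.quint8>.
    static mlir::Type quantizedActivationType(mlir::MLIRContext *ctx) {
      auto ui8 = mlir::IntegerType::get(ctx, 8, mlir::IntegerType::Unsigned);
      auto vt = mlir::torch::Torch::ValueTensorType::get(
          ctx, llvm::ArrayRef<int64_t>{1, 1, 7, 7}, ui8);
      return mlir::torch::onnx_c::getQTorchTypeFromTorchIntType(vt);
    }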
@@ -653,7 +653,7 @@ public:
           op, "lhs and rhs of convolution must either be both int or fp");
     }
 
-    if (inputZp && weightZp) {
+    if (inputZp && weightZp && !isa<Torch::NoneType>(bias.getType())) {
       auto biasDTy = bias.getType().cast<RankedTensorType>().getElementType();
       if (!biasDTy.isInteger(32)) {
         return rewriter.notifyMatchFailure(
@@ -64,10 +64,6 @@ public:
     if (operands.size() < 3)
       return failure();
 
-    Value bias = operands[2];
-    if (bias.getDefiningOp<AtenDequantizeTensorOp>())
-      return failure();
-
     Value lhsScale;
     if (auto qLhs =
             operands[0].getDefiningOp<Aten_MakePerTensorQuantizedTensorOp>())
@@ -82,10 +78,17 @@ public:
       return failure();
 
     auto resultTy = cast<ValueTensorType>(op.getType());
-    auto biasTy = bias.getType().cast<ValueTensorType>();
     if (!isa<mlir::FloatType>(resultTy.getDtype()))
       return failure();
 
+    Value bias = operands[2];
+    auto biasTy = bias.getType().dyn_cast<ValueTensorType>();
+
+    if (biasTy) {
+      auto biasETy = biasTy.getOptionalDtype();
+      if (!biasETy || !isa<mlir::FloatType>(biasETy))
+        return failure();
+    }
+
     Value biasScale = rewriter.create<AtenMulFloatOp>(
         op.getLoc(), lhsScale.getType(), lhsScale, rhsScale);
 
@@ -95,6 +98,8 @@ public:
         rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0));
 
     auto qi32Ty = rewriter.getType<QInt32Type>();
+
+    if (biasTy) {
       auto newBiasTy =
           rewriter.getType<ValueTensorType>(biasTy.getOptionalSizes(), qi32Ty);
       Value dtype = getDtypeIntValueForType(rewriter, op.getLoc(), qi32Ty);
@@ -106,8 +111,8 @@ public:
             biasTy.getOptionalSizes(),
             rewriter.getIntegerType(32, IntegerType::Signed)),
         bias);
-
-    operands[2] = bias;
+      operands[2] = bias;
+    }
 
     auto convTy = rewriter.getType<ValueTensorType>(
         resultTy.getOptionalSizes(),
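The last few hunks rework the quantized-convolution fusion so a missing bias no longer kills the match: the bias operand is read only after the result-type check, probed with dyn_cast instead of cast, and every bias-specific step is guarded. The idiom in a hedged nutshell (names as in the hunks above, body elided):

    // A !torch.none bias yields a null biasTy, so the bias requantization is
    // skipped entirely instead of crashing on an invalid cast.
    auto biasTy = bias.getType().dyn_cast<ValueTensorType>();
    if (biasTy) {
      // ... requantize the bias to qint32 and write it back to operands[2] ...
    }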
@@ -47,6 +47,83 @@ func.func @test_quantizelinear_i32(%arg0: !torch.vtensor<[6],f32>, %arg1: !torch
 
 // -----
 
+// CHECK-LABEL: @test_qlinearconv_nobias
+func.func @test_qlinearconv_nobias(%arg0: !torch.vtensor<[1,1,7,7],ui8>, %arg1: !torch.vtensor<[],f32>, %arg2: !torch.vtensor<[],ui8>, %arg3: !torch.vtensor<[1,1,1,1],ui8>, %arg4: !torch.vtensor<[1],f32>, %arg5: !torch.vtensor<[1],ui8>, %arg6: !torch.vtensor<[],f32>, %arg7: !torch.vtensor<[],ui8>) -> !torch.vtensor<[1,1,7,7],ui8> attributes {torch.onnx_meta.ir_version = 5 : si64, torch.onnx_meta.opset_version = 10 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
+  %0 = torch.operator "onnx.QLinearConv"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7) : (!torch.vtensor<[1,1,7,7],ui8>, !torch.vtensor<[],f32>, !torch.vtensor<[],ui8>, !torch.vtensor<[1,1,1,1],ui8>, !torch.vtensor<[1],f32>, !torch.vtensor<[1],ui8>, !torch.vtensor<[],f32>, !torch.vtensor<[],ui8>) -> !torch.vtensor<[1,1,7,7],ui8>
+  // CHECK: %[[aZp:.+]] = torch.aten.item %arg2 : !torch.vtensor<[],ui8> -> !torch.int
+  // CHECK: %[[bZp:.+]] = torch.aten.item %arg5 : !torch.vtensor<[1],ui8> -> !torch.int
+  // CHECK: %[[cZp:.+]] = torch.aten.item %arg7 : !torch.vtensor<[],ui8> -> !torch.int
+  // CHECK: %[[aScale:.+]] = torch.aten.item %arg1 : !torch.vtensor<[],f32> -> !torch.float
+  // CHECK: %[[bScale:.+]] = torch.aten.item %arg4 : !torch.vtensor<[1],f32> -> !torch.float
+  // CHECK: %[[cScale:.+]] = torch.aten.item %arg6 : !torch.vtensor<[],f32> -> !torch.float
+  // CHECK: %[[A:.+]] = torch.aten._make_per_tensor_quantized_tensor %arg0, %[[aScale]], %[[aZp]] : !torch.vtensor<[1,1,7,7],ui8>, !torch.float, !torch.int -> !torch.vtensor<[1,1,7,7],!torch.quint8>
+  // CHECK: %[[B:.+]] = torch.aten._make_per_tensor_quantized_tensor %arg3, %[[bScale]], %[[bZp]] : !torch.vtensor<[1,1,1,1],ui8>, !torch.float, !torch.int -> !torch.vtensor<[1,1,1,1],!torch.quint8>
+  // CHECK: %[[INT0_0:.+]] = torch.constant.int 0
+  // CHECK: %[[INT0_1:.+]] = torch.constant.int 0
+  // CHECK: %[[INT1_0:.+]] = torch.constant.int 1
+  // CHECK: %[[INT1_1:.+]] = torch.constant.int 1
+  // CHECK: %[[INT1_2:.+]] = torch.constant.int 1
+  // CHECK: %[[INT1_3:.+]] = torch.constant.int 1
+  // CHECK: %[[INT0_2:.+]] = torch.constant.int 0
+  // CHECK: %[[PAD:.+]] = torch.prim.ListConstruct %[[INT0_0]], %[[INT0_1]]
+  // CHECK: %[[KERNEL:.+]] = torch.prim.ListConstruct %[[INT1_0]], %[[INT1_1]]
+  // CHECK: %[[DILATION:.+]] = torch.prim.ListConstruct %[[INT1_2]], %[[INT1_3]]
+  // CHECK: %[[STRIDE:.+]] = torch.prim.ListConstruct %[[INT0_2]], %[[INT0_2]]
+  // CHECK: %[[FALSE:.+]] = torch.constant.bool false
+  // CHECK: %[[NONE:.+]] = torch.constant.none
+  // CHECK: %[[INT1_5:.+]] = torch.constant.int 1
+  // CHECK: %[[CONV:.+]] = torch.aten.convolution %[[A]], %[[B]], %[[NONE]], %[[DILATION]], %[[PAD]], %[[KERNEL]], %[[FALSE]], %[[STRIDE]], %[[INT1_5]] : !torch.vtensor<[1,1,7,7],!torch.quint8>, !torch.vtensor<[1,1,1,1],!torch.quint8>, !torch.none, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int -> !torch.vtensor<[1,1,7,7],!torch.qint32>
+  // CHECK: %[[convScale:.+]] = torch.aten.mul.float %[[aScale]], %[[bScale]] : !torch.float, !torch.float -> !torch.float
+  // CHECK: %[[INT0_6:.+]] = torch.constant.int 0
+  // CHECK: %[[C:.+]] = torch.aten._make_per_tensor_quantized_tensor %[[CONV]], %[[convScale]], %[[INT0_6]] : !torch.vtensor<[1,1,7,7],!torch.qint32>, !torch.float, !torch.int -> !torch.vtensor<[1,1,7,7],!torch.qint32>
+  // CHECK: %[[DEQ:.+]] = torch.aten.dequantize.self %[[C]] : !torch.vtensor<[1,1,7,7],!torch.qint32> -> !torch.vtensor<[1,1,7,7],f32>
+  // CHECK: %[[INT13:.+]] = torch.constant.int 13
+  // CHECK: %[[QUANT:.+]] = torch.aten.quantize_per_tensor %[[DEQ]], %[[cScale]], %[[cZp]], %[[INT13]] : !torch.vtensor<[1,1,7,7],f32>, !torch.float, !torch.int, !torch.int -> !torch.vtensor<[1,1,7,7],!torch.quint8>
+  // CHECK: %[[INT:.+]] = torch.aten.int_repr %[[QUANT]] : !torch.vtensor<[1,1,7,7],!torch.quint8> -> !torch.vtensor<[1,1,7,7],ui8>
+  // CHECK: return %[[INT]] : !torch.vtensor<[1,1,7,7],ui8>
+  return %0 : !torch.vtensor<[1,1,7,7],ui8>
+}
+
+// -----
+
+// CHECK-LABEL: @test_qlinearconv_bias
+func.func @test_qlinearconv_bias(%arg0: !torch.vtensor<[1,1,7,7],ui8>, %arg1: !torch.vtensor<[],f32>, %arg2: !torch.vtensor<[],ui8>, %arg3: !torch.vtensor<[1,1,1,1],ui8>, %arg4: !torch.vtensor<[1],f32>, %arg5: !torch.vtensor<[1],ui8>, %arg6: !torch.vtensor<[],f32>, %arg7: !torch.vtensor<[],ui8>, %arg8 : !torch.vtensor<[7],si32>) -> !torch.vtensor<[1,1,7,7],ui8> attributes {torch.onnx_meta.ir_version = 5 : si64, torch.onnx_meta.opset_version = 10 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
+  %0 = torch.operator "onnx.QLinearConv"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8) : (!torch.vtensor<[1,1,7,7],ui8>, !torch.vtensor<[],f32>, !torch.vtensor<[],ui8>, !torch.vtensor<[1,1,1,1],ui8>, !torch.vtensor<[1],f32>, !torch.vtensor<[1],ui8>, !torch.vtensor<[],f32>, !torch.vtensor<[],ui8>, !torch.vtensor<[7],si32>) -> !torch.vtensor<[1,1,7,7],ui8>
+  // CHECK: %[[aZp:.+]] = torch.aten.item %arg2 : !torch.vtensor<[],ui8> -> !torch.int
+  // CHECK: %[[bZp:.+]] = torch.aten.item %arg5 : !torch.vtensor<[1],ui8> -> !torch.int
+  // CHECK: %[[cZp:.+]] = torch.aten.item %arg7 : !torch.vtensor<[],ui8> -> !torch.int
+  // CHECK: %[[aScale:.+]] = torch.aten.item %arg1 : !torch.vtensor<[],f32> -> !torch.float
+  // CHECK: %[[bScale:.+]] = torch.aten.item %arg4 : !torch.vtensor<[1],f32> -> !torch.float
+  // CHECK: %[[cScale:.+]] = torch.aten.item %arg6 : !torch.vtensor<[],f32> -> !torch.float
+  // CHECK: %[[A:.+]] = torch.aten._make_per_tensor_quantized_tensor %arg0, %[[aScale]], %[[aZp]] : !torch.vtensor<[1,1,7,7],ui8>, !torch.float, !torch.int -> !torch.vtensor<[1,1,7,7],!torch.quint8>
+  // CHECK: %[[B:.+]] = torch.aten._make_per_tensor_quantized_tensor %arg3, %[[bScale]], %[[bZp]] : !torch.vtensor<[1,1,1,1],ui8>, !torch.float, !torch.int -> !torch.vtensor<[1,1,1,1],!torch.quint8>
+  // CHECK: %[[INT0_0:.+]] = torch.constant.int 0
+  // CHECK: %[[INT0_1:.+]] = torch.constant.int 0
+  // CHECK: %[[INT1_0:.+]] = torch.constant.int 1
+  // CHECK: %[[INT1_1:.+]] = torch.constant.int 1
+  // CHECK: %[[INT1_2:.+]] = torch.constant.int 1
+  // CHECK: %[[INT1_3:.+]] = torch.constant.int 1
+  // CHECK: %[[INT0_2:.+]] = torch.constant.int 0
+  // CHECK: %[[PAD:.+]] = torch.prim.ListConstruct %[[INT0_0]], %[[INT0_1]]
+  // CHECK: %[[KERNEL:.+]] = torch.prim.ListConstruct %[[INT1_0]], %[[INT1_1]]
+  // CHECK: %[[DILATION:.+]] = torch.prim.ListConstruct %[[INT1_2]], %[[INT1_3]]
+  // CHECK: %[[STRIDE:.+]] = torch.prim.ListConstruct %[[INT0_2]], %[[INT0_2]]
+  // CHECK: %[[FALSE:.+]] = torch.constant.bool false
+  // CHECK: %[[INT1_5:.+]] = torch.constant.int 1
+  // CHECK: %[[CONV:.+]] = torch.aten.convolution %[[A]], %[[B]], %arg8, %[[DILATION]], %[[PAD]], %[[KERNEL]], %[[FALSE]], %[[STRIDE]], %[[INT1_5]] : !torch.vtensor<[1,1,7,7],!torch.quint8>, !torch.vtensor<[1,1,1,1],!torch.quint8>, !torch.vtensor<[7],si32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int -> !torch.vtensor<[1,1,7,7],!torch.qint32>
+  // CHECK: %[[convScale:.+]] = torch.aten.mul.float %[[aScale]], %[[bScale]] : !torch.float, !torch.float -> !torch.float
+  // CHECK: %[[INT0_6:.+]] = torch.constant.int 0
+  // CHECK: %[[C:.+]] = torch.aten._make_per_tensor_quantized_tensor %[[CONV]], %[[convScale]], %[[INT0_6]] : !torch.vtensor<[1,1,7,7],!torch.qint32>, !torch.float, !torch.int -> !torch.vtensor<[1,1,7,7],!torch.qint32>
+  // CHECK: %[[DEQ:.+]] = torch.aten.dequantize.self %[[C]] : !torch.vtensor<[1,1,7,7],!torch.qint32> -> !torch.vtensor<[1,1,7,7],f32>
+  // CHECK: %[[INT13:.+]] = torch.constant.int 13
+  // CHECK: %[[QUANT:.+]] = torch.aten.quantize_per_tensor %[[DEQ]], %[[cScale]], %[[cZp]], %[[INT13]] : !torch.vtensor<[1,1,7,7],f32>, !torch.float, !torch.int, !torch.int -> !torch.vtensor<[1,1,7,7],!torch.quint8>
+  // CHECK: %[[INT:.+]] = torch.aten.int_repr %[[QUANT]] : !torch.vtensor<[1,1,7,7],!torch.quint8> -> !torch.vtensor<[1,1,7,7],ui8>
+  // CHECK: return %[[INT]] : !torch.vtensor<[1,1,7,7],ui8>
+  return %0 : !torch.vtensor<[1,1,7,7],ui8>
+}
+
+// -----
+
 // CHECK-LABEL: @test_qlinearmatmul_2D
 func.func @test_qlinearmatmul_2D(%arg0: !torch.vtensor<[2,4],ui8>, %arg1: !torch.vtensor<[1],f32>, %arg2: !torch.vtensor<[1],ui8>, %arg3: !torch.vtensor<[4,3],ui8>, %arg4: !torch.vtensor<[1],f32>, %arg5: !torch.vtensor<[1],ui8>, %arg6: !torch.vtensor<[1],f32>, %arg7: !torch.vtensor<[1],ui8>) -> !torch.vtensor<[2,3],ui8> attributes {torch.onnx_meta.ir_version = 5 : si64, torch.onnx_meta.opset_version = 10 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
   %0 = torch.operator "onnx.QLinearMatMul"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7) : (!torch.vtensor<[2,4],ui8>, !torch.vtensor<[1],f32>, !torch.vtensor<[1],ui8>, !torch.vtensor<[4,3],ui8>, !torch.vtensor<[1],f32>, !torch.vtensor<[1],ui8>, !torch.vtensor<[1],f32>, !torch.vtensor<[1],ui8>) -> !torch.vtensor<[2,3],ui8>
@@ -28,8 +28,8 @@ func.func @mm(%arg0: !torch.vtensor<[4, 4],si8>, %arg1: !torch.vtensor<[4, 4],si
 
 // -----
 
-// CHECK-LABEL: @convolution
-func.func @convolution(%arg0: !torch.vtensor<[1,3,8,8],si8>, %arg1: !torch.vtensor<[3,3,2,2],si8>, %arg2 : !torch.vtensor<[3], f32>) -> !torch.vtensor<[1,3,7,7],f32> {
+// CHECK-LABEL: @convolution_bias
+func.func @convolution_bias(%arg0: !torch.vtensor<[1,3,8,8],si8>, %arg1: !torch.vtensor<[3,3,2,2],si8>, %arg2 : !torch.vtensor<[3], f32>) -> !torch.vtensor<[1,3,7,7],f32> {
   %scale = torch.constant.float 0.5
   %false = torch.constant.bool false
   %zero = torch.constant.int 0
@@ -60,3 +60,38 @@ func.func @convolution(%arg0: !torch.vtensor<[1,3,8,8],si8>, %arg1: !torch.vtens
   // CHECK: %[[FOUT:.+]] = torch.aten.dequantize.tensor %[[QOUT]] : !torch.vtensor<[1,3,7,7],!torch.qint32> -> !torch.vtensor<[1,3,7,7],f32>
   return %16 : !torch.vtensor<[1,3,7,7],f32>
 }
+
+
+// -----
+
+// CHECK-LABEL: @convolution_nobias
+func.func @convolution_nobias(%arg0: !torch.vtensor<[1,3,8,8],si8>, %arg1: !torch.vtensor<[3,3,2,2],si8>) -> !torch.vtensor<[1,3,7,7],f32> {
+  %scale = torch.constant.float 0.5
+  %false = torch.constant.bool false
+  %zero = torch.constant.int 0
+  %one = torch.constant.int 1
+  %zp = torch.constant.int -128
+  %none = torch.constant.none
+  %6 = torch.aten._make_per_tensor_quantized_tensor %arg0, %scale, %one : !torch.vtensor<[1,3,8,8],si8>, !torch.float, !torch.int -> !torch.vtensor<[1,3,8,8],!torch.qint8>
+  %7 = torch.aten.dequantize.tensor %6 : !torch.vtensor<[1,3,8,8],!torch.qint8> -> !torch.vtensor<[1,3,8,8],f32>
+  %12 = torch.aten._make_per_tensor_quantized_tensor %arg1, %scale, %zero : !torch.vtensor<[3,3,2,2],si8>, !torch.float, !torch.int -> !torch.vtensor<[3,3,2,2],!torch.qint8>
+  %13 = torch.aten.dequantize.tensor %12 : !torch.vtensor<[3,3,2,2],!torch.qint8> -> !torch.vtensor<[3,3,2,2],f32>
+  %14 = torch.prim.ListConstruct %one, %one : (!torch.int, !torch.int) -> !torch.list<int>
+  %15 = torch.prim.ListConstruct %zero, %zero : (!torch.int, !torch.int) -> !torch.list<int>
+  %16 = torch.aten.convolution %7, %13, %none, %14, %15, %14, %false, %15, %one : !torch.vtensor<[1,3,8,8],f32>, !torch.vtensor<[3,3,2,2],f32>, !torch.none, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int -> !torch.vtensor<[1,3,7,7],f32>
+
+  // CHECK-DAG: %[[SCALEO:.+]] = torch.constant.float 2.500000e-01
+  // CHECK-DAG: %[[HALF:.+]] = torch.constant.float 5.000000e-01
+  // CHECK-DAG: %[[FALSE:.+]] = torch.constant.bool false
+  // CHECK-DAG: %[[ZERO:.+]] = torch.constant.int 0
+  // CHECK-DAG: %[[ONE:.+]] = torch.constant.int 1
+  // CHECK-DAG: %[[NONE:.+]] = torch.constant.none
+  // CHECK-DAG: %[[QLHS:.+]] = torch.aten._make_per_tensor_quantized_tensor %arg0, %[[HALF]], %[[ONE]] : !torch.vtensor<[1,3,8,8],si8>, !torch.float, !torch.int -> !torch.vtensor<[1,3,8,8],!torch.qint8>
+  // CHECK-DAG: %[[QRHS:.+]] = torch.aten._make_per_tensor_quantized_tensor %arg1, %[[HALF]], %[[ZERO]] : !torch.vtensor<[3,3,2,2],si8>, !torch.float, !torch.int -> !torch.vtensor<[3,3,2,2],!torch.qint8>
+  // CHECK-DAG: %[[ONES:.+]] = torch.prim.ListConstruct %[[ONE]], %[[ONE]] : (!torch.int, !torch.int) -> !torch.list<int>
+  // CHECK-DAG: %[[ZEROS:.+]] = torch.prim.ListConstruct %[[ZERO]], %[[ZERO]] : (!torch.int, !torch.int) -> !torch.list<int>
+  // CHECK-DAG: %[[CONV:.+]] = torch.aten.convolution %[[QLHS]], %[[QRHS]], %[[NONE]], %[[ONES]], %[[ZEROS]], %[[ONES]], %[[FALSE]], %[[ZEROS]], %[[ONE]] : !torch.vtensor<[1,3,8,8],!torch.qint8>, !torch.vtensor<[3,3,2,2],!torch.qint8>, !torch.none, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int -> !torch.vtensor<[1,3,7,7],si32>
+  // CHECK-DAG: %[[QOUT:.+]] = torch.aten._make_per_tensor_quantized_tensor %[[CONV]], %[[SCALEO]], %[[ZERO]] : !torch.vtensor<[1,3,7,7],si32>, !torch.float, !torch.int -> !torch.vtensor<[1,3,7,7],!torch.qint32>
+  // CHECK: %[[FOUT:.+]] = torch.aten.dequantize.tensor %[[QOUT]] : !torch.vtensor<[1,3,7,7],!torch.qint32> -> !torch.vtensor<[1,3,7,7],f32>
+  return %16 : !torch.vtensor<[1,3,7,7],f32>
+}