[onnx] Add OnnxToTorch support for `onnx.ConvInteger` (#3179)

All e2e iree tests compiled, but they have the run issue of mismatch of
dtype like the following
```
expected:
1x1x2x2xsi32=[[[12 16][24 28]]]
actual:
1x1x2x2xi32=[[[12 16][24 28]]]
```
pull/3212/head
jinchen 2024-04-23 09:42:02 -07:00 committed by GitHub
parent db3842f2e8
commit ddb29c2c02
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 213 additions and 0 deletions

View File

@ -981,6 +981,157 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
return success();
});
patterns.onOp(
"ConvInteger", 10,
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
std::string autoPad;
if (binder.customOpNameStringAttr(autoPad, "auto_pad", "NOTSET"))
return failure();
if (autoPad != "NOTSET")
// TODO: Add support for `auto_pad` != "NOTSET"
return rewriter.notifyMatchFailure(
binder.op, "unsupported conversion: auto_pad != NOTSET");
Torch::ValueTensorType resultType;
Value input, weight, inputZp, weightZp;
int64_t group;
if (binder.tensorOperandAtIndex(input, 0) ||
binder.tensorOperandAtIndex(weight, 1) ||
binder.s64IntegerAttr(group, "group", 1) ||
binder.tensorResultType(resultType))
return failure();
auto inputTy = dyn_cast<Torch::ValueTensorType>(input.getType());
auto weightTy = dyn_cast<Torch::ValueTensorType>(weight.getType());
if (!weightTy || !weightTy.hasSizes())
return rewriter.notifyMatchFailure(
binder.op, "Expected weight type having sizes");
ArrayRef<int64_t> weightShape = weightTy.getSizes();
SmallVector<int64_t> kernelShape;
if (binder.s64IntegerArrayAttr(kernelShape, "kernel_shape", {}))
return failure();
if (kernelShape.size()) {
if (kernelShape.size() != weightShape.size() - 2) {
return rewriter.notifyMatchFailure(
binder.op,
"unsupported conversion: kernel_shape list size should have "
"number of values equal to weight_rank - 2");
} else {
for (unsigned i = 0; i < kernelShape.size(); i++) {
if (weightShape[i + 2] != kernelShape[i])
return rewriter.notifyMatchFailure(
binder.op, "unsupported conversion: kernel_shape value "
"should be equal to the weight tensor shape");
}
}
}
// Determine the rank of input tensor.
std::optional<unsigned> maybeRank = Torch::getTensorRank(input);
if (!maybeRank)
return rewriter.notifyMatchFailure(binder.op,
"Unimplemented: unranked tensor");
unsigned rank = *maybeRank;
SmallVector<int64_t> padding, strides, dilations;
SmallVector<int64_t> defaultPadding(rank - 2, 0),
defaultStrides(rank - 2, 1), defaultDilations(rank - 2, 1);
// Padding for the beginning and ending along each spatial axis, it can
// take any value greater than or equal to 0. The value represent the
// number of pixels added to the beginning and end part of the
// corresponding axis. pads format should be as follow [x1_begin,
// x2_begin…x1_end, x2_end,…], where xi_begin the number of pixels added
// at the beginning of axis i and xi_end, the number of pixels added at
// the end of axis i.
if (binder.s64IntegerArrayAttr(padding, "pads", defaultPadding))
return failure();
if (padding.size() != rank - 2 && padding.size() != 2 * (rank - 2))
return rewriter.notifyMatchFailure(
binder.op, "padding list size does not match the number of axes");
if (binder.s64IntegerArrayAttr(dilations, "dilations",
defaultDilations))
return failure();
if (dilations.size() != rank - 2)
return rewriter.notifyMatchFailure(
binder.op,
"dilations list size does not match the number of axes");
if (binder.s64IntegerArrayAttr(strides, "strides", defaultStrides))
return failure();
if (strides.size() != rank - 2)
return rewriter.notifyMatchFailure(
binder.op, "strides list size does not match the number of axes");
Value scale = rewriter.create<Torch::ConstantFloatOp>(
binder.getLoc(), rewriter.getType<Torch::FloatType>(),
rewriter.getF64FloatAttr(1.0));
if (binder.tensorOperandAtIndex(inputZp, 2)) {
inputZp = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getI64IntegerAttr(0));
} else {
inputZp = rewriter.create<Torch::AtenItemOp>(
binder.getLoc(), rewriter.getType<Torch::IntType>(), inputZp);
}
if (binder.tensorOperandAtIndex(weightZp, 3))
weightZp = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getI64IntegerAttr(0));
// TODO: support per channel quantization if weightZp is a 1-D tensor
if (auto zpTy = dyn_cast<Torch::ValueTensorType>(weightZp.getType())) {
for (auto dim : zpTy.getSizes())
if (dim != 1)
return failure();
weightZp = rewriter.create<Torch::AtenItemOp>(
binder.getLoc(), rewriter.getType<Torch::IntType>(), weightZp);
}
SmallVector<Value> cstPadding;
if (padding.size() != 2 * (rank - 2)) {
for (int64_t i : padding) {
cstPadding.push_back(rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getI64IntegerAttr(i)));
}
} else {
for (unsigned i = 0; i < padding.size() / 2; i++) {
if (padding[i] != padding[i + (padding.size() / 2)])
// TODO: Add support for different padding values for the
// beginning and ending along each spatial axis
return rewriter.notifyMatchFailure(
binder.op,
"unsupported conversion: padding values for the beginning "
"and ending along each spatial axis must be equal");
cstPadding.push_back(rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getI64IntegerAttr(padding[i])));
}
}
Value paddingList = rewriter.create<Torch::PrimListConstructOp>(
binder.getLoc(),
rewriter.getType<Torch::ListType>(
rewriter.getType<Torch::IntType>()),
cstPadding);
Value dilationsList =
createConstantIntList(binder, rewriter, dilations);
Value stridesList = createConstantIntList(binder, rewriter, strides);
Value outputPaddingList =
createConstantIntList(binder, rewriter, {0, 0});
Value transposed =
rewriter.create<Torch::ConstantBoolOp>(binder.getLoc(), false);
Value bias = rewriter.create<Torch::ConstantNoneOp>(binder.getLoc());
Value cstGroup = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getI64IntegerAttr(group));
Type inputQTy = getQTorchTypeFromTorchIntType(inputTy);
Type weightQTy = getQTorchTypeFromTorchIntType(weightTy);
input = rewriter.create<Torch::Aten_MakePerTensorQuantizedTensorOp>(
binder.getLoc(), inputQTy, input, scale, inputZp);
weight = rewriter.create<Torch::Aten_MakePerTensorQuantizedTensorOp>(
binder.getLoc(), weightQTy, weight, scale, weightZp);
rewriter.replaceOpWithNewOp<Torch::AtenConvolutionOp>(
binder.op, resultType, input, weight, bias, stridesList,
paddingList, dilationsList, transposed, outputPaddingList,
cstGroup);
return success();
});
patterns.onOp(
"ConvTranspose", 11,
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
std::string autoPad;

View File

@ -938,6 +938,68 @@ func.func @test_conv_with_bias_strides_padding(%arg0: !torch.vtensor<[?,?,224,22
// -----
// CHECK-LABEL: @test_convinteger_without_padding
func.func @test_convinteger_without_padding(%arg0: !torch.vtensor<[1,1,3,3],ui8>, %arg1: !torch.vtensor<[1,1,2,2],ui8>, %arg2: !torch.vtensor<[],ui8>, %arg3: !torch.vtensor<[1],ui8>) -> !torch.vtensor<[1,1,2,2],si32> attributes {torch.onnx_meta.ir_version = 5 : si64, torch.onnx_meta.opset_version = 17 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: %[[NONE:.*]] = torch.constant.none
// CHECK: %[[SCALE:.*]] = torch.constant.float 1.000000e+00
// CHECK: %[[INPUT_ZP:.*]] = torch.aten.item %arg2 : !torch.vtensor<[],ui8> -> !torch.int
// CHECK: %[[WEIGHT_ZP:.*]] = torch.aten.item %arg3 : !torch.vtensor<[1],ui8> -> !torch.int
// CHECK: %[[C0:.*]] = torch.constant.int 0
// CHECK: %[[C0_0:.*]] = torch.constant.int 0
// CHECK: %[[PADDING:.*]] = torch.prim.ListConstruct %[[C0]], %[[C0_0]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[C1_0:.*]] = torch.constant.int 1
// CHECK: %[[C1_1:.*]] = torch.constant.int 1
// CHECK: %[[DILATIONS:.*]] = torch.prim.ListConstruct %[[C1_0]], %[[C1_1]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[C1_2:.*]] = torch.constant.int 1
// CHECK: %[[C1_3:.*]] = torch.constant.int 1
// CHECK: %[[STRIDE:.*]] = torch.prim.ListConstruct %[[C1_2]], %[[C1_3]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[C0_1:.*]] = torch.constant.int 0
// CHECK: %[[C0_2:.*]] = torch.constant.int 0
// CHECK: %[[OUTPUT_PADDING:.*]] = torch.prim.ListConstruct %[[C0_1]], %[[C0_2]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[TRANSPOSED:.*]] = torch.constant.bool false
// CHECK: %[[BIAS:.*]] = torch.constant.none
// CHECK: %[[GROUPS:.*]] = torch.constant.int 1
// CHECK: %[[INPUT:.*]] = torch.aten._make_per_tensor_quantized_tensor %arg0, %[[SCALE]], %[[INPUT_ZP]] : !torch.vtensor<[1,1,3,3],ui8>, !torch.float, !torch.int -> !torch.vtensor<[1,1,3,3],!torch.quint8>
// CHECK: %[[WEIGHT:.*]] = torch.aten._make_per_tensor_quantized_tensor %arg1, %[[SCALE]], %[[WEIGHT_ZP]] : !torch.vtensor<[1,1,2,2],ui8>, !torch.float, !torch.int -> !torch.vtensor<[1,1,2,2],!torch.quint8>
// CHECK: torch.aten.convolution %[[INPUT]], %[[WEIGHT]], %[[BIAS]], %[[STRIDE]], %[[PADDING]], %[[DILATIONS]], %[[TRANSPOSED]], %[[OUTPUT_PADDING]], %[[GROUPS]] : !torch.vtensor<[1,1,3,3],!torch.quint8>, !torch.vtensor<[1,1,2,2],!torch.quint8>, !torch.none, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int -> !torch.vtensor<[1,1,2,2],si32>
%none = torch.constant.none
%0 = torch.operator "onnx.ConvInteger"(%arg0, %arg1, %arg2, %arg3) : (!torch.vtensor<[1,1,3,3],ui8>, !torch.vtensor<[1,1,2,2],ui8>, !torch.vtensor<[],ui8>, !torch.vtensor<[1],ui8>) -> !torch.vtensor<[1,1,2,2],si32>
return %0 : !torch.vtensor<[1,1,2,2],si32>
}
// -----
// CHECK-LABEL: @test_convinteger_with_padding
func.func @test_convinteger_with_padding(%arg0: !torch.vtensor<[1,1,3,3],ui8>, %arg1: !torch.vtensor<[1,1,2,2],ui8>, %arg2: !torch.vtensor<[],ui8>) -> !torch.vtensor<[1,1,4,4],si32> attributes {torch.onnx_meta.ir_version = 5 : si64, torch.onnx_meta.opset_version = 17 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: %[[NONE:.*]] = torch.constant.none
// CHECK: %[[SCALE:.*]] = torch.constant.float 1.000000e+00
// CHECK: %[[INPUT_ZP:.*]] = torch.aten.item %arg2 : !torch.vtensor<[],ui8> -> !torch.int
// CHECK: %[[WEIGHT_ZP:.*]] = torch.constant.int 0
// CHECK: %[[C1_0:.*]] = torch.constant.int 1
// CHECK: %[[C1_1:.*]] = torch.constant.int 1
// CHECK: %[[PADDING:.*]] = torch.prim.ListConstruct %[[C1_0]], %[[C1_1]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[C1_2:.*]] = torch.constant.int 1
// CHECK: %[[C1_3:.*]] = torch.constant.int 1
// CHECK: %[[DILATIONS:.*]] = torch.prim.ListConstruct %[[C1_2]], %[[C1_3]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[C1_4:.*]] = torch.constant.int 1
// CHECK: %[[C1_5:.*]] = torch.constant.int 1
// CHECK: %[[STRIDE:.*]] = torch.prim.ListConstruct %[[C1_4]], %[[C1_5]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[C0:.*]] = torch.constant.int 0
// CHECK: %[[C0_0:.*]] = torch.constant.int 0
// CHECK: %[[OUTPUT_PADDING:.*]] = torch.prim.ListConstruct %[[C0]], %[[C0_0]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[TRANSPOSED:.*]] = torch.constant.bool false
// CHECK: %[[BIAS:.*]] = torch.constant.none
// CHECK: %[[GROUPS:.*]] = torch.constant.int 1
// CHECK: %[[INPUT:.*]] = torch.aten._make_per_tensor_quantized_tensor %arg0, %[[SCALE]], %[[INPUT_ZP]] : !torch.vtensor<[1,1,3,3],ui8>, !torch.float, !torch.int -> !torch.vtensor<[1,1,3,3],!torch.quint8>
// CHECK: %[[WEIGHT:.*]] = torch.aten._make_per_tensor_quantized_tensor %arg1, %[[SCALE]], %[[WEIGHT_ZP]] : !torch.vtensor<[1,1,2,2],ui8>, !torch.float, !torch.int -> !torch.vtensor<[1,1,2,2],!torch.quint8>
// CHECK: torch.aten.convolution %[[INPUT]], %[[WEIGHT]], %[[BIAS]], %[[STRIDE]], %[[PADDING]], %[[DILATIONS]], %[[TRANSPOSED]], %[[OUTPUT_PADDING]], %[[GROUPS]] : !torch.vtensor<[1,1,3,3],!torch.quint8>, !torch.vtensor<[1,1,2,2],!torch.quint8>, !torch.none, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int -> !torch.vtensor<[1,1,4,4],si32>
%none = torch.constant.none
%0 = torch.operator "onnx.ConvInteger"(%arg0, %arg1, %arg2) {torch.onnx.pads = [1 : si64, 1 : si64, 1 : si64, 1 : si64]} : (!torch.vtensor<[1,1,3,3],ui8>, !torch.vtensor<[1,1,2,2],ui8>, !torch.vtensor<[],ui8>) -> !torch.vtensor<[1,1,4,4],si32>
return %0 : !torch.vtensor<[1,1,4,4],si32>
}
// -----
// CHECK-LABEL: @test_convtranspose_dilations
func.func @test_convtranspose_dilations(%arg0: !torch.vtensor<[1,1,3,3],f32>, %arg1: !torch.vtensor<[1,1,2,2],f32>) -> !torch.vtensor<[1,1,5,5],f32> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: %[[C0:.*]] = torch.constant.int 0