mirror of https://github.com/llvm/torch-mlir
[ONNX] Conv op adds support for asymmetric padding. (#3426)
Supports asymmetric padding by applying torch.nn.functional.pad to the input before performing the convolution.

Signed-off-by: Suraj Sudhir <suraj.sudhir@arm.com>
parent 94838ca44d
commit 1c2778dd56
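The rewrite relies on a simple identity: convolving with asymmetric zero padding gives the same result as explicitly zero-padding the input and then convolving with no padding. The sketch below is a minimal, standalone illustration of that identity in plain C++ (a naive 1-D convolution; it is not part of this commit, and all names in it are hypothetical).

// Minimal sketch (assumption: plain C++17, no MLIR or torch dependencies).
// Demonstrates: conv(input, padLeft, padRight) == conv(pad(input), 0, 0).
#include <cassert>
#include <cstddef>
#include <vector>

// Naive 1-D convolution (stride 1, dilation 1) with per-side zero padding.
static std::vector<float> conv1d(const std::vector<float> &input,
                                 const std::vector<float> &kernel,
                                 size_t padLeft, size_t padRight) {
  std::vector<float> padded(padLeft, 0.0f);
  padded.insert(padded.end(), input.begin(), input.end());
  padded.insert(padded.end(), padRight, 0.0f);

  std::vector<float> out;
  for (size_t i = 0; i + kernel.size() <= padded.size(); ++i) {
    float acc = 0.0f;
    for (size_t k = 0; k < kernel.size(); ++k)
      acc += padded[i + k] * kernel[k];
    out.push_back(acc);
  }
  return out;
}

int main() {
  std::vector<float> input = {1, 2, 3, 4, 5};
  std::vector<float> kernel = {1, -1, 2};

  // Asymmetric padding: 2 on the left, 1 on the right.
  auto direct = conv1d(input, kernel, /*padLeft=*/2, /*padRight=*/1);

  // Equivalent: pad the input explicitly, then convolve with zero padding,
  // which mirrors what the lowering does via an explicit pad op followed by
  // a convolution whose own padding is zero.
  std::vector<float> prePadded(2, 0.0f);
  prePadded.insert(prePadded.end(), input.begin(), input.end());
  prePadded.push_back(0.0f);
  auto viaExplicitPad = conv1d(prePadded, kernel, 0, 0);

  assert(direct == viaExplicitPad);
  return 0;
}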
@@ -951,7 +951,6 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
           return rewriter.notifyMatchFailure(
               binder.op, "unsupported conversion: auto_pad != NOTSET");
         }
-
         Torch::ValueTensorType resultType;
         Value input, weight;
         int64_t group;
@@ -1034,24 +1033,95 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
 
         SmallVector<Value> cstPadding, cstStrides, cstDilations,
             cstOutputPadding;
+        Value paddedInput = input;
+        Value paddingList;
         if (padding.size() != 2 * (rank - 2)) {
           for (int64_t i : padding) {
             cstPadding.push_back(rewriter.create<Torch::ConstantIntOp>(
                 binder.getLoc(), rewriter.getI64IntegerAttr(i)));
           }
+          paddingList = rewriter.create<Torch::PrimListConstructOp>(
+              binder.getLoc(),
+              Torch::ListType::get(
+                  Torch::IntType::get(binder.op->getContext())),
+              cstPadding);
         } else {
+          // ONNX offers pads in the format listing all starting dims, then all
+          // ending dims, e.g. {t, l, b, r} for conv2d. Torch by default accepts
+          // only starting dims, e.g. {t, l}. However, we can support padding at
+          // the beginning and end of each dimension by first performing
+          // torch.nn.functional.pad on the input. But this requires the pad
+          // values to be rearranged since torch pad() takes pads in the order
+          // rightmost dim start and end, then next to last, and so on, e.g. {l,
+          // r, t, b}.
+          bool matchedPads = true;
           for (unsigned i = 0; i < padding.size() / 2; i++) {
             if (padding[i] != padding[i + (padding.size() / 2)]) {
-              // TODO: Add support for different padding values for the
-              // beginning and ending along each spatial axis
-              return rewriter.notifyMatchFailure(
-                  binder.op,
-                  "unsupported conversion: padding values for the beginning "
-                  "and ending along each spatial axis must be equal");
+              matchedPads = false;
+              break;
             }
+          }
+          if (matchedPads) {
+            for (unsigned i = 0; i < padding.size() / 2; i++) {
               cstPadding.push_back(rewriter.create<Torch::ConstantIntOp>(
                   binder.getLoc(), rewriter.getI64IntegerAttr(padding[i])));
             }
+            paddingList = rewriter.create<Torch::PrimListConstructOp>(
+                binder.getLoc(),
+                Torch::ListType::get(
+                    Torch::IntType::get(binder.op->getContext())),
+                cstPadding);
+          } else {
+            SmallVector<Value> padsRearrange;
+            SmallVector<Value> inputPaddingList;
+            for (uint32_t i = 0; i < padding.size() / 2; i++) {
+              padsRearrange.emplace_back(rewriter.create<Torch::ConstantIntOp>(
+                  binder.getLoc(), rewriter.getI64IntegerAttr(padding[i])));
+              padsRearrange.emplace_back(rewriter.create<Torch::ConstantIntOp>(
+                  binder.getLoc(), rewriter.getI64IntegerAttr(
+                                       padding[(padding.size() / 2) + i])));
+              inputPaddingList.emplace_back(
+                  rewriter.create<Torch::ConstantIntOp>(
+                      binder.getLoc(), rewriter.getI64IntegerAttr(0)));
+            }
+            // The conv op itself will have no padding since the actual padding
+            // is performed using the torch.pad preceding it.
+            paddingList = rewriter.create<Torch::PrimListConstructOp>(
+                binder.getLoc(),
+                Torch::ListType::get(
+                    Torch::IntType::get(binder.op->getContext())),
+                inputPaddingList);
+            Value padsSizeList =
+                rewriter
+                    .create<Torch::PrimListConstructOp>(
+                        binder.getLoc(),
+                        Torch::ListType::get(
+                            rewriter.getType<Torch::IntType>()),
+                        padsRearrange)
+                    .getResult();
+            Value modeVal = rewriter.create<Torch::ConstantStrOp>(
+                binder.getLoc(), rewriter.getStringAttr("constant"));
+            Value constantValue;
+            auto inputTensorType =
+                cast<Torch::ValueTensorType>(input.getType());
+            if (isa<IntegerType>(inputTensorType.getDtype()))
+              constantValue = rewriter.create<Torch::ConstantIntOp>(
+                  binder.getLoc(), rewriter.getI64IntegerAttr(0));
+            if (isa<FloatType>(inputTensorType.getDtype()))
+              constantValue = rewriter.create<Torch::ConstantFloatOp>(
+                  binder.getLoc(), rewriter.getF64FloatAttr(0.0f));
+            // Pad output shape must be computed explicitly from the pad values
+            SmallVector<int64_t> newInputShape(inputTensorType.getSizes());
+            for (uint32_t i = 0; i < padding.size() / 2; i++) {
+              newInputShape[2 + i] +=
+                  padding[i] + padding[(padding.size() / 2) + i];
+            }
+            auto padTy = rewriter.getType<Torch::ValueTensorType>(
+                newInputShape, inputTensorType.getDtype());
+            paddedInput = rewriter.create<Torch::AtenPadOp>(
+                binder.getLoc(), padTy, input, padsSizeList, modeVal,
+                constantValue);
+          }
         }
         for (int64_t i : dilations) {
           cstDilations.push_back(rewriter.create<Torch::ConstantIntOp>(
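To make the comment about pad ordering concrete, here is a small standalone sketch (plain C++, not the rewriter code itself; the helper names are illustrative only). It reorders an ONNX-style pads attribute, all begin values followed by all end values, into begin/end pairs starting from the innermost dimension as torch.nn.functional.pad expects, and grows each spatial size by begin plus end, which is how the lowering computes the static shape of the pad result.

// Minimal sketch (assumption: plain C++17; not the torch-mlir code itself).
#include <cassert>
#include <cstdint>
#include <vector>

// ONNX Conv "pads": all begin values first, then all end values,
// e.g. {t, l, b, r} for a 2-D convolution.
// torch pad(): begin/end pairs starting from the rightmost (innermost)
// dimension, e.g. {l, r, t, b} for the same 2-D case.
static std::vector<int64_t>
onnxPadsToTorchPads(const std::vector<int64_t> &pads) {
  size_t nSpatial = pads.size() / 2;
  std::vector<int64_t> torchPads;
  for (size_t i = 0; i < nSpatial; ++i) {
    size_t dim = nSpatial - 1 - i;             // innermost spatial dim first
    torchPads.push_back(pads[dim]);            // begin
    torchPads.push_back(pads[nSpatial + dim]); // end
  }
  return torchPads;
}

// Each spatial size grows by begin + end, matching how the lowering builds
// the static result shape of the pad op (spatial dims start at index 2).
static std::vector<int64_t> paddedShape(std::vector<int64_t> shape,
                                        const std::vector<int64_t> &pads) {
  size_t nSpatial = pads.size() / 2;
  for (size_t i = 0; i < nSpatial; ++i)
    shape[2 + i] += pads[i] + pads[nSpatial + i];
  return shape;
}

int main() {
  std::vector<int64_t> onnxPads = {1, 2, 3, 4}; // {t, l, b, r}
  assert(onnxPadsToTorchPads(onnxPads) ==
         (std::vector<int64_t>{2, 4, 1, 3}));   // {l, r, t, b}

  std::vector<int64_t> nchw = {1, 3, 10, 20};
  assert(paddedShape(nchw, onnxPads) ==
         (std::vector<int64_t>{1, 3, 10 + 1 + 3, 20 + 2 + 4}));
  return 0;
}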
@@ -1065,10 +1135,6 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
             binder.getLoc(), rewriter.getI64IntegerAttr(0));
         cstOutputPadding = {cstZero, cstZero};
 
-        Value paddingList = rewriter.create<Torch::PrimListConstructOp>(
-            binder.getLoc(),
-            Torch::ListType::get(Torch::IntType::get(binder.op->getContext())),
-            cstPadding);
         Value dilationsList = rewriter.create<Torch::PrimListConstructOp>(
             binder.getLoc(),
             Torch::ListType::get(Torch::IntType::get(binder.op->getContext())),
@@ -1095,7 +1161,7 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
             binder.getLoc(), rewriter.getI64IntegerAttr(group));
 
         rewriter.replaceOpWithNewOp<Torch::AtenConvolutionOp>(
-            binder.op, resultType, input, weight, bias, stridesList,
+            binder.op, resultType, paddedInput, weight, bias, stridesList,
             paddingList, dilationsList, transposed, outputPaddingList,
             cstGroup);
         return success();
@@ -946,12 +946,12 @@ func.func @test_averagepool_with_padding(%arg0: !torch.vtensor<[1,20,64,48],f32>
 func.func @test_conv_with_strides_no_padding(%arg0: !torch.vtensor<[1,1,7,5],f32>, %arg1: !torch.vtensor<[1,1,3,3],f32>) -> !torch.vtensor<[1,1,3,2],f32> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
   // CHECK: %[[C0:.*]] = torch.constant.int 0
   // CHECK: %[[C0_0:.*]] = torch.constant.int 0
+  // CHECK: %[[PADDING:.*]] = torch.prim.ListConstruct %[[C0]], %[[C0_0]] : (!torch.int, !torch.int) -> !torch.list<int>
   // CHECK: %[[C1:.*]] = torch.constant.int 1
   // CHECK: %[[C1_0:.*]] = torch.constant.int 1
   // CHECK: %[[C2:.*]] = torch.constant.int 2
   // CHECK: %[[C2_0:.*]] = torch.constant.int 2
   // CHECK: %[[C0_1:.*]] = torch.constant.int 0
-  // CHECK: %[[PADDING:.*]] = torch.prim.ListConstruct %[[C0]], %[[C0_0]] : (!torch.int, !torch.int) -> !torch.list<int>
   // CHECK: %[[DILATIONS:.*]] = torch.prim.ListConstruct %[[C1]], %[[C1_0]] : (!torch.int, !torch.int) -> !torch.list<int>
   // CHECK: %[[STRIDE:.*]] = torch.prim.ListConstruct %[[C2]], %[[C2_0]] : (!torch.int, !torch.int) -> !torch.list<int>
   // CHECK: %[[OUTPUT_PADDING:.*]] = torch.prim.ListConstruct %[[C0_1]], %[[C0_1]] : (!torch.int, !torch.int) -> !torch.list<int>
@@ -969,12 +969,12 @@ func.func @test_conv_with_strides_no_padding(%arg0: !torch.vtensor<[1,1,7,5],f32
 func.func @test_conv_with_strides_padding(%arg0: !torch.vtensor<[1,1,7,5],f32>, %arg1: !torch.vtensor<[1,1,3,3],f32>) -> !torch.vtensor<[1,1,4,3],f32> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
   // CHECK: %[[C1:.*]] = torch.constant.int 1
   // CHECK: %[[C1_0:.*]] = torch.constant.int 1
+  // CHECK: %[[PADDING:.*]] = torch.prim.ListConstruct %[[C1]], %[[C1_0]] : (!torch.int, !torch.int) -> !torch.list<int>
   // CHECK: %[[C1_1:.*]] = torch.constant.int 1
   // CHECK: %[[C1_2:.*]] = torch.constant.int 1
   // CHECK: %[[C2:.*]] = torch.constant.int 2
   // CHECK: %[[C2_0:.*]] = torch.constant.int 2
   // CHECK: %[[C0:.*]] = torch.constant.int 0
-  // CHECK: %[[PADDING:.*]] = torch.prim.ListConstruct %[[C1]], %[[C1_0]] : (!torch.int, !torch.int) -> !torch.list<int>
   // CHECK: %[[DILATIONS:.*]] = torch.prim.ListConstruct %[[C1_1]], %[[C1_2]] : (!torch.int, !torch.int) -> !torch.list<int>
   // CHECK: %[[STRIDE:.*]] = torch.prim.ListConstruct %[[C2]], %[[C2_0]] : (!torch.int, !torch.int) -> !torch.list<int>
   // CHECK: %[[OUTPUT_PADDING:.*]] = torch.prim.ListConstruct %[[C0]], %[[C0]] : (!torch.int, !torch.int) -> !torch.list<int>
@@ -992,12 +992,12 @@ func.func @test_conv_with_strides_padding(%arg0: !torch.vtensor<[1,1,7,5],f32>,
 func.func @test_conv_with_bias_strides_padding(%arg0: !torch.vtensor<[?,?,224,224],f32>, %arg1: !torch.vtensor<[64,3,7,7],f32>, %arg2: !torch.vtensor<[64],f32>) -> !torch.vtensor<[?,64,112,112],f32> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
   // CHECK: %[[C3:.*]] = torch.constant.int 3
   // CHECK: %[[C3_0:.*]] = torch.constant.int 3
+  // CHECK: %[[PADDING:.*]] = torch.prim.ListConstruct %[[C3]], %[[C3_0]] : (!torch.int, !torch.int) -> !torch.list<int>
   // CHECK: %[[C1:.*]] = torch.constant.int 1
   // CHECK: %[[C1_0:.*]] = torch.constant.int 1
   // CHECK: %[[C2:.*]] = torch.constant.int 2
   // CHECK: %[[C2_0:.*]] = torch.constant.int 2
   // CHECK: %[[C0:.*]] = torch.constant.int 0
-  // CHECK: %[[PADDING:.*]] = torch.prim.ListConstruct %[[C3]], %[[C3_0]] : (!torch.int, !torch.int) -> !torch.list<int>
   // CHECK: %[[DILATIONS:.*]] = torch.prim.ListConstruct %[[C1]], %[[C1_0]] : (!torch.int, !torch.int) -> !torch.list<int>
   // CHECK: %[[STRIDE:.*]] = torch.prim.ListConstruct %[[C2]], %[[C2_0]] : (!torch.int, !torch.int) -> !torch.list<int>
   // CHECK: %[[OUTPUT_PADDING:.*]] = torch.prim.ListConstruct %[[C0]], %[[C0]] : (!torch.int, !torch.int) -> !torch.list<int>
@@ -60,12 +60,12 @@ func.func @test_qlinearconv_nobias(%arg0: !torch.vtensor<[1,1,7,7],ui8>, %arg1:
   // CHECK: %[[B:.+]] = torch.aten._make_per_tensor_quantized_tensor %arg3, %[[bScale]], %[[bZp]] : !torch.vtensor<[1,1,1,1],ui8>, !torch.float, !torch.int -> !torch.vtensor<[1,1,1,1],!torch.quint8>
   // CHECK: %[[INT0_0:.+]] = torch.constant.int 0
   // CHECK: %[[INT0_1:.+]] = torch.constant.int 0
+  // CHECK: %[[PAD:.+]] = torch.prim.ListConstruct %[[INT0_0]], %[[INT0_1]]
   // CHECK: %[[INT1_0:.+]] = torch.constant.int 1
   // CHECK: %[[INT1_1:.+]] = torch.constant.int 1
   // CHECK: %[[INT1_2:.+]] = torch.constant.int 1
   // CHECK: %[[INT1_3:.+]] = torch.constant.int 1
   // CHECK: %[[INT0_2:.+]] = torch.constant.int 0
-  // CHECK: %[[PAD:.+]] = torch.prim.ListConstruct %[[INT0_0]], %[[INT0_1]]
   // CHECK: %[[KERNEL:.+]] = torch.prim.ListConstruct %[[INT1_0]], %[[INT1_1]]
   // CHECK: %[[DILATION:.+]] = torch.prim.ListConstruct %[[INT1_2]], %[[INT1_3]]
   // CHECK: %[[STRIDE:.+]] = torch.prim.ListConstruct %[[INT0_2]], %[[INT0_2]]
@@ -99,12 +99,12 @@ func.func @test_qlinearconv_bias(%arg0: !torch.vtensor<[1,1,7,7],ui8>, %arg1: !t
   // CHECK: %[[B:.+]] = torch.aten._make_per_tensor_quantized_tensor %arg3, %[[bScale]], %[[bZp]] : !torch.vtensor<[1,1,1,1],ui8>, !torch.float, !torch.int -> !torch.vtensor<[1,1,1,1],!torch.quint8>
   // CHECK: %[[INT0_0:.+]] = torch.constant.int 0
   // CHECK: %[[INT0_1:.+]] = torch.constant.int 0
+  // CHECK: %[[PAD:.+]] = torch.prim.ListConstruct %[[INT0_0]], %[[INT0_1]]
   // CHECK: %[[INT1_0:.+]] = torch.constant.int 1
   // CHECK: %[[INT1_1:.+]] = torch.constant.int 1
   // CHECK: %[[INT1_2:.+]] = torch.constant.int 1
   // CHECK: %[[INT1_3:.+]] = torch.constant.int 1
   // CHECK: %[[INT0_2:.+]] = torch.constant.int 0
-  // CHECK: %[[PAD:.+]] = torch.prim.ListConstruct %[[INT0_0]], %[[INT0_1]]
   // CHECK: %[[KERNEL:.+]] = torch.prim.ListConstruct %[[INT1_0]], %[[INT1_1]]
   // CHECK: %[[DILATION:.+]] = torch.prim.ListConstruct %[[INT1_2]], %[[INT1_3]]
   // CHECK: %[[STRIDE:.+]] = torch.prim.ListConstruct %[[INT0_2]], %[[INT0_2]]