mirror of https://github.com/llvm/torch-mlir
[MLIR][ONNX] Add OnnxToTorch support for Conv and ConvTranspose op.
This commit adds the OnnxToTorch support for Conv and ConvTranspose op. Signed-Off By: vivekkhandelwal1424@gmail.compull/2685/head snapshot-20231221.1059
parent
d75cff6cd1
commit
3226241521
|
@ -426,6 +426,342 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
|
|||
}
|
||||
return failure();
|
||||
});
|
||||
patterns.onOp(
|
||||
"Conv", 11, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
|
||||
std::string autoPad;
|
||||
if (binder.customOpNameStringAttr(autoPad, "auto_pad", "NOTSET"))
|
||||
return failure();
|
||||
if (autoPad != "NOTSET") {
|
||||
// TODO: Add support for `auto_pad` != "NOTSET"
|
||||
return rewriter.notifyMatchFailure(
|
||||
binder.op, "unsupported conversion: auto_pad != NOTSET");
|
||||
}
|
||||
|
||||
Torch::ValueTensorType resultType;
|
||||
Value input, weight;
|
||||
int64_t group;
|
||||
if (binder.tensorOperands(input, weight) ||
|
||||
binder.s64IntegerAttr(group, "group", 1) ||
|
||||
binder.tensorResultType(resultType))
|
||||
return failure();
|
||||
|
||||
auto weightTensorType = weight.getType().cast<Torch::ValueTensorType>();
|
||||
if (!weightTensorType || !weightTensorType.hasSizes()) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
binder.op, "Expected weight type having sizes");
|
||||
}
|
||||
ArrayRef<int64_t> weightShape = weightTensorType.getSizes();
|
||||
SmallVector<int64_t> kernelShape;
|
||||
if (binder.s64IntegerArrayAttr(kernelShape, "kernel_shape", {}))
|
||||
return failure();
|
||||
if (kernelShape.size()) {
|
||||
if (kernelShape.size() != weightShape.size() - 2) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
binder.op,
|
||||
"unsupported conversion: kernel_shape list size should have "
|
||||
"number of values equal to weight_rank - 2");
|
||||
} else {
|
||||
for (unsigned i = 0; i < kernelShape.size(); i++) {
|
||||
if (weightShape[i + 2] != kernelShape[i]) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
binder.op, "unsupported conversion: kernel_shape value "
|
||||
"should be equal to the weight tensor shape");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Determine the rank of input tensor.
|
||||
std::optional<unsigned> maybeRank = Torch::getTensorRank(input);
|
||||
if (!maybeRank)
|
||||
return rewriter.notifyMatchFailure(binder.op,
|
||||
"Unimplemented: unranked tensor");
|
||||
unsigned rank = *maybeRank;
|
||||
|
||||
SmallVector<int64_t> padding, strides, dilations;
|
||||
SmallVector<int64_t> defaultPadding, defaultStrides, defaultDilations;
|
||||
for (unsigned i = 0; i < rank - 2; i++) {
|
||||
defaultPadding.push_back(0);
|
||||
defaultStrides.push_back(1);
|
||||
defaultDilations.push_back(1);
|
||||
}
|
||||
// Padding for the beginning and ending along each spatial axis, it can
|
||||
// take any value greater than or equal to 0. The value represent the
|
||||
// number of pixels added to the beginning and end part of the
|
||||
// corresponding axis. pads format should be as follow [x1_begin,
|
||||
// x2_begin…x1_end, x2_end,…], where xi_begin the number of pixels added
|
||||
// at the beginning of axis i and xi_end, the number of pixels added at
|
||||
// the end of axis i.
|
||||
if (binder.s64IntegerArrayAttr(padding, "pads", defaultPadding)) {
|
||||
return failure();
|
||||
}
|
||||
if (padding.size() != rank - 2 && padding.size() != 2 * (rank - 2)) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
binder.op, "padding list size does not match the number of axes");
|
||||
}
|
||||
if (binder.s64IntegerArrayAttr(dilations, "dilations",
|
||||
defaultDilations)) {
|
||||
return failure();
|
||||
}
|
||||
if (dilations.size() != rank - 2) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
binder.op,
|
||||
"dilations list size does not match the number of axes");
|
||||
}
|
||||
if (binder.s64IntegerArrayAttr(strides, "strides", defaultStrides)) {
|
||||
return failure();
|
||||
}
|
||||
if (strides.size() != rank - 2) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
binder.op, "strides list size does not match the number of axes");
|
||||
}
|
||||
|
||||
SmallVector<Value> cstPadding, cstStrides, cstDilations,
|
||||
cstOutputPadding;
|
||||
if (padding.size() != 2 * (rank - 2)) {
|
||||
for (int64_t i : padding) {
|
||||
cstPadding.push_back(rewriter.create<Torch::ConstantIntOp>(
|
||||
binder.getLoc(), rewriter.getI64IntegerAttr(i)));
|
||||
}
|
||||
} else {
|
||||
for (unsigned i = 0; i < padding.size() / 2; i++) {
|
||||
if (padding[i] != padding[i + (padding.size() / 2)]) {
|
||||
// TODO: Add support for different padding values for the
|
||||
// beginning and ending along each spatial axis
|
||||
return rewriter.notifyMatchFailure(
|
||||
binder.op,
|
||||
"unsupported conversion: padding values for the beginning "
|
||||
"and ending along each spatial axis must be equal");
|
||||
}
|
||||
cstPadding.push_back(rewriter.create<Torch::ConstantIntOp>(
|
||||
binder.getLoc(), rewriter.getI64IntegerAttr(padding[i])));
|
||||
}
|
||||
}
|
||||
for (int64_t i : dilations) {
|
||||
cstDilations.push_back(rewriter.create<Torch::ConstantIntOp>(
|
||||
binder.getLoc(), rewriter.getI64IntegerAttr(i)));
|
||||
}
|
||||
for (int64_t i : strides) {
|
||||
cstStrides.push_back(rewriter.create<Torch::ConstantIntOp>(
|
||||
binder.getLoc(), rewriter.getI64IntegerAttr(i)));
|
||||
}
|
||||
Value cstZero = rewriter.create<Torch::ConstantIntOp>(
|
||||
binder.getLoc(), rewriter.getI64IntegerAttr(0));
|
||||
cstOutputPadding = {cstZero, cstZero};
|
||||
|
||||
Value paddingList = rewriter.create<Torch::PrimListConstructOp>(
|
||||
binder.getLoc(),
|
||||
Torch::ListType::get(Torch::IntType::get(binder.op->getContext())),
|
||||
cstPadding);
|
||||
Value dilationsList = rewriter.create<Torch::PrimListConstructOp>(
|
||||
binder.getLoc(),
|
||||
Torch::ListType::get(Torch::IntType::get(binder.op->getContext())),
|
||||
cstDilations);
|
||||
Value stridesList = rewriter.create<Torch::PrimListConstructOp>(
|
||||
binder.getLoc(),
|
||||
Torch::ListType::get(Torch::IntType::get(binder.op->getContext())),
|
||||
cstStrides);
|
||||
Value outputPaddingList = rewriter.create<Torch::PrimListConstructOp>(
|
||||
binder.getLoc(),
|
||||
Torch::ListType::get(Torch::IntType::get(binder.op->getContext())),
|
||||
cstOutputPadding);
|
||||
Value transposed =
|
||||
rewriter.create<Torch::ConstantBoolOp>(binder.getLoc(), false);
|
||||
Value bias;
|
||||
if (binder.op->getNumOperands() == 3) {
|
||||
if (binder.tensorOperandAtIndex(bias, 2)) {
|
||||
return failure();
|
||||
}
|
||||
} else {
|
||||
bias = rewriter.create<Torch::ConstantNoneOp>(binder.getLoc());
|
||||
}
|
||||
Value cstGroup = rewriter.create<Torch::ConstantIntOp>(
|
||||
binder.getLoc(), rewriter.getI64IntegerAttr(group));
|
||||
|
||||
rewriter.replaceOpWithNewOp<Torch::AtenConvolutionOp>(
|
||||
binder.op, resultType, input, weight, bias, stridesList,
|
||||
paddingList, dilationsList, transposed, outputPaddingList,
|
||||
cstGroup);
|
||||
return success();
|
||||
});
|
||||
patterns.onOp(
|
||||
"ConvTranspose", 11,
|
||||
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
|
||||
std::string autoPad;
|
||||
if (binder.customOpNameStringAttr(autoPad, "auto_pad", "NOTSET"))
|
||||
return failure();
|
||||
if (autoPad != "NOTSET") {
|
||||
// TODO: Add support for `auto_pad` != "NOTSET"
|
||||
return rewriter.notifyMatchFailure(
|
||||
binder.op, "unsupported conversion: auto_pad != NOTSET");
|
||||
}
|
||||
SmallVector<int64_t> outputShape;
|
||||
if (binder.s64IntegerArrayAttr(outputShape, "output_shape", {}))
|
||||
return failure();
|
||||
if (outputShape.size()) {
|
||||
// TODO: Add support for non-None output_shape value.
|
||||
return rewriter.notifyMatchFailure(
|
||||
binder.op,
|
||||
"unsupported conversion: output_shape should be absent");
|
||||
}
|
||||
Torch::ValueTensorType resultType;
|
||||
Value input, weight;
|
||||
int64_t group;
|
||||
if (binder.tensorOperands(input, weight) ||
|
||||
binder.s64IntegerAttr(group, "group", 1) ||
|
||||
binder.tensorResultType(resultType))
|
||||
return failure();
|
||||
|
||||
auto weightTensorType = weight.getType().cast<Torch::ValueTensorType>();
|
||||
if (!weightTensorType || !weightTensorType.hasSizes()) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
binder.op, "Expected weight type having sizes");
|
||||
}
|
||||
ArrayRef<int64_t> weightShape = weightTensorType.getSizes();
|
||||
SmallVector<int64_t> kernelShape;
|
||||
if (binder.s64IntegerArrayAttr(kernelShape, "kernel_shape", {}))
|
||||
return failure();
|
||||
if (kernelShape.size()) {
|
||||
if (kernelShape.size() != weightShape.size() - 2) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
binder.op,
|
||||
"unsupported conversion: kernel_shape list size should have "
|
||||
"number of values equal to weight_rank - 2");
|
||||
} else {
|
||||
for (unsigned i = 0; i < kernelShape.size(); i++) {
|
||||
if (weightShape[i + 2] != kernelShape[i]) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
binder.op, "unsupported conversion: kernel_shape value "
|
||||
"should be equal to the weight tensor shape");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Determine the rank of input tensor.
|
||||
std::optional<unsigned> maybeRank = Torch::getTensorRank(input);
|
||||
if (!maybeRank)
|
||||
return rewriter.notifyMatchFailure(binder.op,
|
||||
"Unimplemented: unranked tensor");
|
||||
unsigned rank = *maybeRank;
|
||||
|
||||
SmallVector<int64_t> padding, strides, dilations, outputPadding;
|
||||
SmallVector<int64_t> defaultPadding, defaultStrides, defaultDilations, defaultOutputPadding;
|
||||
for (unsigned i = 0; i < rank - 2; i++) {
|
||||
defaultPadding.push_back(0);
|
||||
defaultStrides.push_back(1);
|
||||
defaultDilations.push_back(1);
|
||||
defaultOutputPadding.push_back(0);
|
||||
}
|
||||
// Padding for the beginning and ending along each spatial axis, it can
|
||||
// take any value greater than or equal to 0. The value represent the
|
||||
// number of pixels added to the beginning and end part of the
|
||||
// corresponding axis. pads format should be as follow [x1_begin,
|
||||
// x2_begin…x1_end, x2_end,…], where xi_begin the number of pixels added
|
||||
// at the beginning of axis i and xi_end, the number of pixels added at
|
||||
// the end of axis i.
|
||||
if (binder.s64IntegerArrayAttr(padding, "pads", defaultPadding)) {
|
||||
return failure();
|
||||
}
|
||||
if (padding.size() != rank - 2 && padding.size() != 2 * (rank - 2)) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
binder.op, "padding list size does not match the number of axes");
|
||||
}
|
||||
if (binder.s64IntegerArrayAttr(dilations, "dilations",
|
||||
defaultDilations)) {
|
||||
return failure();
|
||||
}
|
||||
if (dilations.size() != rank - 2) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
binder.op,
|
||||
"dilations list size does not match the number of axes");
|
||||
}
|
||||
if (binder.s64IntegerArrayAttr(strides, "strides", defaultStrides)) {
|
||||
return failure();
|
||||
}
|
||||
if (strides.size() != rank - 2) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
binder.op, "strides list size does not match the number of axes");
|
||||
}
|
||||
if (binder.s64IntegerArrayAttr(outputPadding, "output_padding",
|
||||
defaultOutputPadding)) {
|
||||
return failure();
|
||||
}
|
||||
if (outputPadding.size() != rank - 2) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
binder.op,
|
||||
"output_padding list size does not match the number of axes");
|
||||
}
|
||||
|
||||
SmallVector<Value> cstPadding, cstStrides, cstDilations,
|
||||
cstOutputPadding;
|
||||
if (padding.size() != 2 * (rank - 2)) {
|
||||
for (int64_t i : padding) {
|
||||
cstPadding.push_back(rewriter.create<Torch::ConstantIntOp>(
|
||||
binder.getLoc(), rewriter.getI64IntegerAttr(i)));
|
||||
}
|
||||
} else {
|
||||
for (unsigned i = 0; i < padding.size() / 2; i++) {
|
||||
if (padding[i] != padding[i + (padding.size() / 2)]) {
|
||||
// TODO: Add support for different padding values for the
|
||||
// beginning and ending along each spatial axis
|
||||
return rewriter.notifyMatchFailure(
|
||||
binder.op,
|
||||
"unsupported conversion: padding values for the beginning "
|
||||
"and ending along each spatial axis must be equal");
|
||||
}
|
||||
cstPadding.push_back(rewriter.create<Torch::ConstantIntOp>(
|
||||
binder.getLoc(), rewriter.getI64IntegerAttr(padding[i])));
|
||||
}
|
||||
}
|
||||
for (int64_t i : dilations) {
|
||||
cstDilations.push_back(rewriter.create<Torch::ConstantIntOp>(
|
||||
binder.getLoc(), rewriter.getI64IntegerAttr(i)));
|
||||
}
|
||||
for (int64_t i : strides) {
|
||||
cstStrides.push_back(rewriter.create<Torch::ConstantIntOp>(
|
||||
binder.getLoc(), rewriter.getI64IntegerAttr(i)));
|
||||
}
|
||||
for (int64_t i : outputPadding) {
|
||||
cstOutputPadding.push_back(rewriter.create<Torch::ConstantIntOp>(
|
||||
binder.getLoc(), rewriter.getI64IntegerAttr(i)));
|
||||
}
|
||||
|
||||
Value paddingList = rewriter.create<Torch::PrimListConstructOp>(
|
||||
binder.getLoc(),
|
||||
Torch::ListType::get(Torch::IntType::get(binder.op->getContext())),
|
||||
cstPadding);
|
||||
Value dilationsList = rewriter.create<Torch::PrimListConstructOp>(
|
||||
binder.getLoc(),
|
||||
Torch::ListType::get(Torch::IntType::get(binder.op->getContext())),
|
||||
cstDilations);
|
||||
Value stridesList = rewriter.create<Torch::PrimListConstructOp>(
|
||||
binder.getLoc(),
|
||||
Torch::ListType::get(Torch::IntType::get(binder.op->getContext())),
|
||||
cstStrides);
|
||||
Value outputPaddingList = rewriter.create<Torch::PrimListConstructOp>(
|
||||
binder.getLoc(),
|
||||
Torch::ListType::get(Torch::IntType::get(binder.op->getContext())),
|
||||
cstOutputPadding);
|
||||
Value transposed =
|
||||
rewriter.create<Torch::ConstantBoolOp>(binder.getLoc(), true);
|
||||
Value bias;
|
||||
if (binder.op->getNumOperands() == 3) {
|
||||
if (binder.tensorOperandAtIndex(bias, 2)) {
|
||||
return failure();
|
||||
}
|
||||
} else {
|
||||
bias = rewriter.create<Torch::ConstantNoneOp>(binder.getLoc());
|
||||
}
|
||||
Value cstGroup = rewriter.create<Torch::ConstantIntOp>(
|
||||
binder.getLoc(), rewriter.getI64IntegerAttr(group));
|
||||
|
||||
rewriter.replaceOpWithNewOp<Torch::AtenConvolutionOp>(
|
||||
binder.op, resultType, input, weight, bias, stridesList,
|
||||
paddingList, dilationsList, transposed, outputPaddingList,
|
||||
cstGroup);
|
||||
return success();
|
||||
});
|
||||
patterns.onOp("Cos", 7,
|
||||
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
|
||||
Torch::ValueTensorType resultType;
|
||||
|
|
|
@ -462,3 +462,133 @@ func.func @test_averagepool_3d_default(%arg0: !torch.vtensor<[1,3,32,32,32],f32>
|
|||
%0 = torch.operator "onnx.AveragePool"(%arg0) {torch.onnx.kernel_shape = [2 : si64, 2 : si64, 2 : si64]} : (!torch.vtensor<[1,3,32,32,32],f32>) -> !torch.vtensor<[1,3,31,31,31],f32>
|
||||
return %0 : !torch.vtensor<[1,3,31,31,31],f32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @test_conv_with_strides_no_padding
|
||||
func.func @test_conv_with_strides_no_padding(%arg0: !torch.vtensor<[1,1,7,5],f32>, %arg1: !torch.vtensor<[1,1,3,3],f32>) -> !torch.vtensor<[1,1,3,2],f32> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||
// CHECK: %[[C0:.*]] = torch.constant.int 0
|
||||
// CHECK: %[[C0_0:.*]] = torch.constant.int 0
|
||||
// CHECK: %[[C1:.*]] = torch.constant.int 1
|
||||
// CHECK: %[[C1_0:.*]] = torch.constant.int 1
|
||||
// CHECK: %[[C2:.*]] = torch.constant.int 2
|
||||
// CHECK: %[[C2_0:.*]] = torch.constant.int 2
|
||||
// CHECK: %[[C0_1:.*]] = torch.constant.int 0
|
||||
// CHECK: %[[PADDING:.*]] = torch.prim.ListConstruct %[[C0]], %[[C0_0]] : (!torch.int, !torch.int) -> !torch.list<int>
|
||||
// CHECK: %[[DILATIONS:.*]] = torch.prim.ListConstruct %[[C1]], %[[C1_0]] : (!torch.int, !torch.int) -> !torch.list<int>
|
||||
// CHECK: %[[STRIDE:.*]] = torch.prim.ListConstruct %[[C2]], %[[C2_0]] : (!torch.int, !torch.int) -> !torch.list<int>
|
||||
// CHECK: %[[OUTPUT_PADDING:.*]] = torch.prim.ListConstruct %[[C0_1]], %[[C0_1]] : (!torch.int, !torch.int) -> !torch.list<int>
|
||||
// CHECK: %[[TRANSPOSED:.*]] = torch.constant.bool false
|
||||
// CHECK: %[[BIAS:.*]] = torch.constant.none
|
||||
// CHECK: %[[GROUPS:.*]] = torch.constant.int 1
|
||||
// CHECK: torch.aten.convolution %arg0, %arg1, %[[BIAS]], %[[STRIDE]], %[[PADDING]], %[[DILATIONS]], %[[TRANSPOSED]], %[[OUTPUT_PADDING]], %[[GROUPS]] : !torch.vtensor<[1,1,7,5],f32>, !torch.vtensor<[1,1,3,3],f32>, !torch.none, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int -> !torch.vtensor<[1,1,3,2],f32>
|
||||
%0 = torch.operator "onnx.Conv"(%arg0, %arg1) {torch.onnx.kernel_shape = [3 : si64, 3 : si64], torch.onnx.pads = [0 : si64, 0 : si64, 0 : si64, 0 : si64], torch.onnx.strides = [2 : si64, 2 : si64]} : (!torch.vtensor<[1,1,7,5],f32>, !torch.vtensor<[1,1,3,3],f32>) -> !torch.vtensor<[1,1,3,2],f32>
|
||||
return %0 : !torch.vtensor<[1,1,3,2],f32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @test_conv_with_strides_padding
|
||||
func.func @test_conv_with_strides_padding(%arg0: !torch.vtensor<[1,1,7,5],f32>, %arg1: !torch.vtensor<[1,1,3,3],f32>) -> !torch.vtensor<[1,1,4,3],f32> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||
// CHECK: %[[C1:.*]] = torch.constant.int 1
|
||||
// CHECK: %[[C1_0:.*]] = torch.constant.int 1
|
||||
// CHECK: %[[C1_1:.*]] = torch.constant.int 1
|
||||
// CHECK: %[[C1_2:.*]] = torch.constant.int 1
|
||||
// CHECK: %[[C2:.*]] = torch.constant.int 2
|
||||
// CHECK: %[[C2_0:.*]] = torch.constant.int 2
|
||||
// CHECK: %[[C0:.*]] = torch.constant.int 0
|
||||
// CHECK: %[[PADDING:.*]] = torch.prim.ListConstruct %[[C1]], %[[C1_0]] : (!torch.int, !torch.int) -> !torch.list<int>
|
||||
// CHECK: %[[DILATIONS:.*]] = torch.prim.ListConstruct %[[C1_1]], %[[C1_2]] : (!torch.int, !torch.int) -> !torch.list<int>
|
||||
// CHECK: %[[STRIDE:.*]] = torch.prim.ListConstruct %[[C2]], %[[C2_0]] : (!torch.int, !torch.int) -> !torch.list<int>
|
||||
// CHECK: %[[OUTPUT_PADDING:.*]] = torch.prim.ListConstruct %[[C0]], %[[C0]] : (!torch.int, !torch.int) -> !torch.list<int>
|
||||
// CHECK: %[[TRANSPOSED:.*]] = torch.constant.bool false
|
||||
// CHECK: %[[BIAS:.*]] = torch.constant.none
|
||||
// CHECK: %[[GROUPS:.*]] = torch.constant.int 1
|
||||
// CHECK: torch.aten.convolution %arg0, %arg1, %[[BIAS]], %[[STRIDE]], %[[PADDING]], %[[DILATIONS]], %[[TRANSPOSED]], %[[OUTPUT_PADDING]], %[[GROUPS]] : !torch.vtensor<[1,1,7,5],f32>, !torch.vtensor<[1,1,3,3],f32>, !torch.none, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int -> !torch.vtensor<[1,1,4,3],f32>
|
||||
%0 = torch.operator "onnx.Conv"(%arg0, %arg1) {torch.onnx.kernel_shape = [3 : si64, 3 : si64], torch.onnx.pads = [1 : si64, 1 : si64, 1 : si64, 1 : si64], torch.onnx.strides = [2 : si64, 2 : si64]} : (!torch.vtensor<[1,1,7,5],f32>, !torch.vtensor<[1,1,3,3],f32>) -> !torch.vtensor<[1,1,4,3],f32>
|
||||
return %0 : !torch.vtensor<[1,1,4,3],f32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @test_convtranspose_dilations
|
||||
func.func @test_convtranspose_dilations(%arg0: !torch.vtensor<[1,1,3,3],f32>, %arg1: !torch.vtensor<[1,1,2,2],f32>) -> !torch.vtensor<[1,1,5,5],f32> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||
// CHECK: %[[C0:.*]] = torch.constant.int 0
|
||||
// CHECK: %[[C0_0:.*]] = torch.constant.int 0
|
||||
// CHECK: %[[C2:.*]] = torch.constant.int 2
|
||||
// CHECK: %[[C2_0:.*]] = torch.constant.int 2
|
||||
// CHECK: %[[C1:.*]] = torch.constant.int 1
|
||||
// CHECK: %[[C1_0:.*]] = torch.constant.int 1
|
||||
// CHECK: %[[C0_1:.*]] = torch.constant.int 0
|
||||
// CHECK: %[[C0_2:.*]] = torch.constant.int 0
|
||||
// CHECK: %[[PADDING:.*]] = torch.prim.ListConstruct %[[C0]], %[[C0_0]] : (!torch.int, !torch.int) -> !torch.list<int>
|
||||
// CHECK: %[[DILATIONS:.*]] = torch.prim.ListConstruct %[[C2]], %[[C2_0]] : (!torch.int, !torch.int) -> !torch.list<int>
|
||||
// CHECK: %[[STRIDE:.*]] = torch.prim.ListConstruct %[[C1]], %[[C1_0]] : (!torch.int, !torch.int) -> !torch.list<int>
|
||||
// CHECK: %[[OUTPUT_PADDING:.*]] = torch.prim.ListConstruct %[[C0_1]], %[[C0_2]] : (!torch.int, !torch.int) -> !torch.list<int>
|
||||
// CHECK: %[[TRANSPOSED:.*]] = torch.constant.bool true
|
||||
// CHECK: %[[BIAS:.*]] = torch.constant.none
|
||||
// CHECK: %[[GROUPS:.*]] = torch.constant.int 1
|
||||
// CHECK: torch.aten.convolution %arg0, %arg1, %[[BIAS]], %[[STRIDE]], %[[PADDING]], %[[DILATIONS]], %[[TRANSPOSED]], %[[OUTPUT_PADDING]], %[[GROUPS]] : !torch.vtensor<[1,1,3,3],f32>, !torch.vtensor<[1,1,2,2],f32>, !torch.none, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int -> !torch.vtensor<[1,1,5,5],f32>
|
||||
%0 = torch.operator "onnx.ConvTranspose"(%arg0, %arg1) {torch.onnx.dilations = [2 : si64, 2 : si64]} : (!torch.vtensor<[1,1,3,3],f32>, !torch.vtensor<[1,1,2,2],f32>) -> !torch.vtensor<[1,1,5,5],f32>
|
||||
return %0 : !torch.vtensor<[1,1,5,5],f32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @test_convtranspose
|
||||
func.func @test_convtranspose(%arg0: !torch.vtensor<[1,1,3,3],f32>, %arg1: !torch.vtensor<[1,2,3,3],f32>) -> !torch.vtensor<[1,2,5,5],f32> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||
// CHECK: %[[C0:.*]] = torch.constant.int 0
|
||||
// CHECK: %[[C0_0:.*]] = torch.constant.int 0
|
||||
// CHECK: %[[C1:.*]] = torch.constant.int 1
|
||||
// CHECK: %[[C1_0:.*]] = torch.constant.int 1
|
||||
// CHECK: %[[C1_1:.*]] = torch.constant.int 1
|
||||
// CHECK: %[[C1_2:.*]] = torch.constant.int 1
|
||||
// CHECK: %[[C0_1:.*]] = torch.constant.int 0
|
||||
// CHECK: %[[C0_2:.*]] = torch.constant.int 0
|
||||
// CHECK: %[[PADDING:.*]] = torch.prim.ListConstruct %[[C0]], %[[C0_0]] : (!torch.int, !torch.int) -> !torch.list<int>
|
||||
// CHECK: %[[DILATIONS:.*]] = torch.prim.ListConstruct %[[C1]], %[[C1_0]] : (!torch.int, !torch.int) -> !torch.list<int>
|
||||
// CHECK: %[[STRIDE:.*]] = torch.prim.ListConstruct %[[C1_1]], %[[C1_2]] : (!torch.int, !torch.int) -> !torch.list<int>
|
||||
// CHECK: %[[OUTPUT_PADDING:.*]] = torch.prim.ListConstruct %[[C0_1]], %[[C0_2]] : (!torch.int, !torch.int) -> !torch.list<int>
|
||||
// CHECK: %[[TRANSPOSED:.*]] = torch.constant.bool true
|
||||
// CHECK: %[[BIAS:.*]] = torch.constant.none
|
||||
// CHECK: %[[GROUPS:.*]] = torch.constant.int 1
|
||||
// CHECK: torch.aten.convolution %arg0, %arg1, %[[BIAS]], %[[STRIDE]], %[[PADDING]], %[[DILATIONS]], %[[TRANSPOSED]], %[[OUTPUT_PADDING]], %[[GROUPS]] : !torch.vtensor<[1,1,3,3],f32>, !torch.vtensor<[1,2,3,3],f32>, !torch.none, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int -> !torch.vtensor<[1,2,5,5],f32>
|
||||
%0 = torch.operator "onnx.ConvTranspose"(%arg0, %arg1) : (!torch.vtensor<[1,1,3,3],f32>, !torch.vtensor<[1,2,3,3],f32>) -> !torch.vtensor<[1,2,5,5],f32>
|
||||
return %0 : !torch.vtensor<[1,2,5,5],f32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @test_convtranspose_pad
|
||||
func.func @test_convtranspose_pad(%arg0: !torch.vtensor<[1,1,3,3],f32>, %arg1: !torch.vtensor<[1,2,3,3],f32>) -> !torch.vtensor<[1,2,10,8],f32> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||
// CHECK: %[[C0:.*]] = torch.constant.int 0
|
||||
// CHECK: %[[C0_0:.*]] = torch.constant.int 0
|
||||
// CHECK: %[[C1:.*]] = torch.constant.int 1
|
||||
// CHECK: %[[C1_0:.*]] = torch.constant.int 1
|
||||
// CHECK: %[[C3:.*]] = torch.constant.int 3
|
||||
// CHECK: %[[C2:.*]] = torch.constant.int 2
|
||||
// CHECK: %[[C1_1:.*]] = torch.constant.int 1
|
||||
// CHECK: %[[C1_2:.*]] = torch.constant.int 1
|
||||
// CHECK: %[[PADDING:.*]] = torch.prim.ListConstruct %[[C0]], %[[C0_0]] : (!torch.int, !torch.int) -> !torch.list<int>
|
||||
// CHECK: %[[DILATIONS:.*]] = torch.prim.ListConstruct %[[C1]], %[[C1_0]] : (!torch.int, !torch.int) -> !torch.list<int>
|
||||
// CHECK: %[[STRIDE:.*]] = torch.prim.ListConstruct %[[C3]], %[[C2]] : (!torch.int, !torch.int) -> !torch.list<int>
|
||||
// CHECK: %[[OUTPUT_PADDING:.*]] = torch.prim.ListConstruct %[[C1_1]], %[[C1_2]] : (!torch.int, !torch.int) -> !torch.list<int>
|
||||
// CHECK: %[[TRANSPOSED:.*]] = torch.constant.bool true
|
||||
// CHECK: %[[BIAS:.*]] = torch.constant.none
|
||||
// CHECK: %[[GROUPS:.*]] = torch.constant.int 1
|
||||
// CHECK: torch.aten.convolution %arg0, %arg1, %[[BIAS]], %[[STRIDE]], %[[PADDING]], %[[DILATIONS]], %[[TRANSPOSED]], %[[OUTPUT_PADDING]], %[[GROUPS]] : !torch.vtensor<[1,1,3,3],f32>, !torch.vtensor<[1,2,3,3],f32>, !torch.none, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int -> !torch.vtensor<[1,2,10,8],f32>
|
||||
%0 = torch.operator "onnx.ConvTranspose"(%arg0, %arg1) {torch.onnx.output_padding = [1 : si64, 1 : si64], torch.onnx.strides = [3 : si64, 2 : si64]} : (!torch.vtensor<[1,1,3,3],f32>, !torch.vtensor<[1,2,3,3],f32>) -> !torch.vtensor<[1,2,10,8],f32>
|
||||
return %0 : !torch.vtensor<[1,2,10,8],f32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @test_convtranspose_pads
|
||||
func.func @test_convtranspose_pads(%arg0: !torch.vtensor<[1,1,3,3],f32>, %arg1: !torch.vtensor<[1,2,3,3],f32>) -> !torch.vtensor<[1,2,7,3],f32> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||
// CHECK: %[[C1:.*]] = torch.constant.int 1
|
||||
// CHECK: %[[C2:.*]] = torch.constant.int 2
|
||||
// CHECK: %[[C1_0:.*]] = torch.constant.int 1
|
||||
// CHECK: %[[C1_1:.*]] = torch.constant.int 1
|
||||
// CHECK: %[[C3:.*]] = torch.constant.int 3
|
||||
// CHECK: %[[C2_0:.*]] = torch.constant.int 2
|
||||
// CHECK: %[[C0:.*]] = torch.constant.int 0
|
||||
// CHECK: %[[C0_1:.*]] = torch.constant.int 0
|
||||
// CHECK: %[[PADDING:.*]] = torch.prim.ListConstruct %[[C1]], %[[C2]] : (!torch.int, !torch.int) -> !torch.list<int>
|
||||
// CHECK: %[[DILATIONS:.*]] = torch.prim.ListConstruct %[[C1_0]], %[[C1_1]] : (!torch.int, !torch.int) -> !torch.list<int>
|
||||
// CHECK: %[[STRIDE:.*]] = torch.prim.ListConstruct %[[C3]], %[[C2_0]] : (!torch.int, !torch.int) -> !torch.list<int>
|
||||
// CHECK: %[[OUTPUT_PADDING:.*]] = torch.prim.ListConstruct %[[C0]], %[[C0_1]] : (!torch.int, !torch.int) -> !torch.list<int>
|
||||
// CHECK: %[[TRANSPOSED:.*]] = torch.constant.bool true
|
||||
// CHECK: %[[BIAS:.*]] = torch.constant.none
|
||||
// CHECK: %[[GROUPS:.*]] = torch.constant.int 1
|
||||
// CHECK: torch.aten.convolution %arg0, %arg1, %[[BIAS]], %[[STRIDE]], %[[PADDING]], %[[DILATIONS]], %[[TRANSPOSED]], %[[OUTPUT_PADDING]], %[[GROUPS]] : !torch.vtensor<[1,1,3,3],f32>, !torch.vtensor<[1,2,3,3],f32>, !torch.none, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int -> !torch.vtensor<[1,2,7,3],f32>
|
||||
%0 = torch.operator "onnx.ConvTranspose"(%arg0, %arg1) {torch.onnx.pads = [1 : si64, 2 : si64, 1 : si64, 2 : si64], torch.onnx.strides = [3 : si64, 2 : si64]} : (!torch.vtensor<[1,1,3,3],f32>, !torch.vtensor<[1,2,3,3],f32>) -> !torch.vtensor<[1,2,7,3],f32>
|
||||
return %0 : !torch.vtensor<[1,2,7,3],f32>
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue