[Torch]Support conv_transpose1d and conv_transpose3d (#3286)

1. Support conv_transpose1d and conv_transpose3d
2. Fix bugs in the convertTransposedConv function in
lib/Conversion/TorchToStablehlo/Linear.cpp
Xinyu Yang 2024-06-03 15:11:12 +08:00 committed by GitHub
parent 617b00b983
commit 23b53050de
7 changed files with 241 additions and 5 deletions


@@ -591,25 +591,32 @@ public:
auto weightShape = weightTy.getShape();
auto nDims = inputTy.getRank();
+auto weightDims = weightTy.getRank();
+auto kernelDims = weightDims - 2;
auto nSpatialDims = nDims - 2;
auto convOutTy = outType;
// Transpose weight
SmallVector<int64_t> perm(nDims);
SmallVector<int64_t> transposeShape(nDims);
-for (int i = 0; i < nDims; i++) {
-if (i < 2)
-perm[i] = nDims - 2 + i;
+// 1d: kernelDims = 1, [0, 1, 2] => [2, 1, 0]
+// 2d: kernelDims = 2, [0, 1, 2, 3] => [2, 3, 1, 0]
+// 3d: kernelDims = 3, [0, 1, 2, 3, 4] => [2, 3, 4, 1, 0]
+for (int i = 0; i < weightDims; i++) {
+if (i < kernelDims)
+perm[i] = 2 + i;
else
-perm[i] = nDims - i - 1;
+perm[i] = kernelDims + 1 - i;
transposeShape[i] = weightShape[perm[i]];
}
+auto reverseDim = llvm::to_vector<4>(llvm::seq<int64_t>(0, kernelDims));
auto transposeTy =
RankedTensorType::get(transposeShape, weightTy.getElementType());
auto transposeOp = rewriter.create<stablehlo::TransposeOp>(
op->getLoc(), transposeTy, weight, perm);
auto reverseOp = rewriter.create<stablehlo::ReverseOp>(
-op->getLoc(), transposeOp, ArrayRef<int64_t>{0, 1});
+op->getLoc(), transposeOp, reverseDim);
// Prepare for transposed convolution
SmallVector<int64_t> stablehloStrideVec(nSpatialDims, 1);
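
A note on the fix above (illustration, not part of the diff): the old code derived the weight permutation from nDims and hard-coded the ReverseOp dimensions to {0, 1}, which is only correct for 2-d transposed convolution. The new code derives everything from kernelDims. The index math, re-expressed as a small Python sketch with names mirroring the C++ variables:

# Kernel axes move to the front, then the two channel axes are swapped.
def transpose_perm(weight_dims):
    kernel_dims = weight_dims - 2
    return [2 + i if i < kernel_dims else kernel_dims + 1 - i
            for i in range(weight_dims)]

assert transpose_perm(3) == [2, 1, 0]        # conv_transpose1d
assert transpose_perm(4) == [2, 3, 1, 0]     # conv_transpose2d
assert transpose_perm(5) == [2, 3, 4, 1, 0]  # conv_transpose3d
# After the transpose the kernel occupies dims [0, kernelDims), which is
# exactly the range reverseDim = seq(0, kernelDims) tells ReverseOp to flip.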


@@ -9110,6 +9110,16 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %1 = call @__torch__.torch.jit._shape_functions.conv_forwards(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %false, %0, %int1) : (!torch.list<int>, !torch.list<int>, !torch.optional<list<int>>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int) -> !torch.list<int>\n"
" return %1 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.conv_transpose1d\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.optional<list<int>>, %arg3: !torch.list<int>, %arg4: !torch.list<int>, %arg5: !torch.list<int>, %arg6: !torch.int, %arg7: !torch.list<int>) -> !torch.list<int> {\n"
" %true = torch.constant.bool true\n"
" %0 = call @\"__torch_mlir_shape_fn.aten.convolution\"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg7, %true, %arg5, %arg6) : (!torch.list<int>, !torch.list<int>, !torch.optional<list<int>>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.conv_transpose3d.input\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.optional<list<int>>, %arg3: !torch.list<int>, %arg4: !torch.list<int>, %arg5: !torch.list<int>, %arg6: !torch.int, %arg7: !torch.list<int>) -> !torch.list<int> {\n"
" %true = torch.constant.bool true\n"
" %0 = call @\"__torch_mlir_shape_fn.aten.convolution\"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg7, %true, %arg5, %arg6) : (!torch.list<int>, !torch.list<int>, !torch.optional<list<int>>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten._convolution\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.optional<list<int>>, %arg3: !torch.list<int>, %arg4: !torch.list<int>, %arg5: !torch.list<int>, %arg6: !torch.bool, %arg7: !torch.list<int>, %arg8: !torch.int, %arg9: !torch.bool, %arg10: !torch.bool, %arg11: !torch.bool, %arg12: !torch.bool) -> !torch.list<int> {\n"
" %0 = call @\"__torch_mlir_shape_fn.aten.convolution\"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8) : (!torch.list<int>, !torch.list<int>, !torch.optional<list<int>>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
@@ -11797,10 +11807,18 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" return %0#1 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.conv_transpose1d\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.optional<tuple<int, int>>, %arg3: !torch.list<int>, %arg4: !torch.list<int>, %arg5: !torch.list<int>, %arg6: !torch.int, %arg7: !torch.list<int>) -> !torch.int {\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" return %0#1 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.conv_transpose2d.input\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.optional<tuple<int, int>>, %arg3: !torch.list<int>, %arg4: !torch.list<int>, %arg5: !torch.list<int>, %arg6: !torch.int, %arg7: !torch.list<int>) -> !torch.int {\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" return %0#1 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.conv_transpose3d.input\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.optional<tuple<int, int>>, %arg3: !torch.list<int>, %arg4: !torch.list<int>, %arg5: !torch.list<int>, %arg6: !torch.int, %arg7: !torch.list<int>) -> !torch.int {\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" return %0#1 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.convolution\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.optional<tuple<int, int>>, %arg3: !torch.list<int>, %arg4: !torch.list<int>, %arg5: !torch.list<int>, %arg6: !torch.bool, %arg7: !torch.list<int>, %arg8: !torch.int) -> !torch.int {\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" return %0#1 : !torch.int\n"


@@ -3633,6 +3633,25 @@ public:
};
} // namespace
+// Decompose aten.conv_transpose1d to aten.convolution
+namespace {
+class DecomposeAtenConvTranspose1dOp
+    : public OpRewritePattern<AtenConvTranspose1dOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(AtenConvTranspose1dOp op,
+                                PatternRewriter &rewriter) const override {
+    Value cstTrue = rewriter.create<Torch::ConstantBoolOp>(op.getLoc(), true);
+    rewriter.replaceOpWithNewOp<AtenConvolutionOp>(
+        op, op->getResultTypes(), op.getInput(), op.getWeight(), op.getBias(),
+        op.getStride(), op.getPadding(), op.getDilation(),
+        /*transposed=*/cstTrue, op.getOutputPadding(), op.getGroups());
+    return success();
+  }
+};
+} // namespace
// Decompose aten.conv_transpose2d to aten.convolution
namespace {
class DecomposeAtenConvTranspose2dOp
@@ -3652,6 +3671,25 @@ public:
};
} // namespace
+// Decompose aten.conv_transpose3d to aten.convolution
+namespace {
+class DecomposeAtenConvTranspose3dOp
+    : public OpRewritePattern<AtenConvTranspose3dInputOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(AtenConvTranspose3dInputOp op,
+                                PatternRewriter &rewriter) const override {
+    Value cstTrue = rewriter.create<Torch::ConstantBoolOp>(op.getLoc(), true);
+    rewriter.replaceOpWithNewOp<AtenConvolutionOp>(
+        op, op->getResultTypes(), op.getInput(), op.getWeight(), op.getBias(),
+        op.getStride(), op.getPadding(), op.getDilation(),
+        /*transposed=*/cstTrue, op.getOutputPadding(), op.getGroups());
+    return success();
+  }
+};
+} // namespace
// The convolution backward op is decomposed as follows:
// inputH, inputW = input.shape[2:]
// output_padding_ = [
@@ -7963,7 +8001,9 @@ public:
DecomposeAten_ConvolutionLikeOp<Aten_ConvolutionDeprecatedOp>>(
patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenConvolutionBackwardOp>(patterns);
+addPatternIfTargetOpIsIllegal<DecomposeAtenConvTranspose1dOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenConvTranspose2dOp>(patterns);
+addPatternIfTargetOpIsIllegal<DecomposeAtenConvTranspose3dOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenArangeOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenArangeStartOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposePrimsIotaOp>(patterns);
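
The two new patterns mirror the existing conv_transpose2d decomposition: each op is replaced by aten.convolution with transposed set to a constant true and every other operand passed through unchanged. The rewrite is an exact identity in eager PyTorch, which is easy to confirm (illustrative sketch using the stock ATen entry points):

import torch

x = torch.rand(5, 2, 5, 6, 7)
w = torch.rand(2, 5, 2, 2, 2)
# aten::conv_transpose3d(input, weight, bias, stride, padding,
#                        output_padding, groups, dilation)
a = torch.ops.aten.conv_transpose3d(
    x, w, None, [2, 2, 2], [1, 1, 1], [0, 0, 0], 1, [1, 1, 1])
# aten::convolution(input, weight, bias, stride, padding, dilation,
#                   transposed, output_padding, groups)
b = torch.ops.aten.convolution(
    x, w, None, [2, 2, 2], [1, 1, 1], [1, 1, 1], True, [0, 0, 0], 1)
assert torch.equal(a, b)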


@@ -428,7 +428,9 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
target.addIllegalOp<AtenConv1dOp>();
target.addIllegalOp<AtenConv2dOp>();
target.addIllegalOp<AtenConv3dOp>();
+target.addIllegalOp<AtenConvTranspose1dOp>();
target.addIllegalOp<AtenConvTranspose2dInputOp>();
+target.addIllegalOp<AtenConvTranspose3dInputOp>();
target.addIllegalOp<AtenArangeOp>();
target.addIllegalOp<AtenArangeStartOp>();
target.addIllegalOp<AtenLinspaceOp>();


@@ -911,6 +911,9 @@ STABLEHLO_PASS_SET = {
"Convolution2DStaticModule_basic",
"ConvolutionBackwardModule2DStatic_basic",
"ConvolutionModule2DTransposeStridedStatic_basic",
"Conv_Transpose1dStaticModule_basic",
"Conv_Transpose2dStaticModule_basic",
"Conv_Transpose3dStaticModule_basic",
"ConstantPad2dStaticModule_basic",
"ConstantPadNdModule_basic",
"ConstantPadNdPartialStaticModule_basic",
@@ -2662,6 +2665,8 @@ ONNX_XFAIL_SET = {
"PrimsIotaModule_basic",
# Failure - unknown
"BernoulliModule_basic",
"Conv_Transpose1dModule_basic",
"Conv_Transpose3dModule_basic",
"Conv2dWithPaddingDilationStrideStaticModule_depthwise_multiplier",
"CopyWithDifferentDTypesAndSizesModule_basic",
"CopyWithDifferentDTypesModule_basic",


@@ -1548,6 +1548,12 @@ def aten〇convolution〡shape(input: List[int], weight: List[int], bias: Option
def aten〇conv1d〡shape(input: List[int], weight: List[int], bias: Optional[List[int]] = None, stride: List[int] = (1,), padding: List[int] = (0,), dilation: List[int] = (1,), groups: int = 1) -> List[int]:
    return upstream_shape_functions.conv_forwards(input, weight, bias, stride, padding, dilation, transposed=False, output_padding=[], groups=1)
+def aten〇conv_transpose1d〡shape(input: List[int], weight: List[int], bias: Optional[List[int]] = None, stride: List[int] = (1,), padding: List[int] = (0,), output_padding: List[int] = (0,), groups: int = 1, dilation: List[int] = (1,)) -> List[int]:
+    return aten〇convolution〡shape(input, weight, bias, stride, padding, dilation, True, output_padding, groups)
+def aten〇conv_transpose3d〇input〡shape(input: List[int], weight: List[int], bias: Optional[List[int]] = None, stride: List[int] = (1, 1, 1,), padding: List[int] = (0, 0, 0,), output_padding: List[int] = (0, 0, 0,), groups: int = 1, dilation: List[int] = (1, 1, 1,)) -> List[int]:
+    return aten〇convolution〡shape(input, weight, bias, stride, padding, dilation, True, output_padding, groups)
def aten〇_convolution〡shape(input: List[int], weight: List[int], bias: Optional[List[int]], stride: List[int], padding: List[int], dilation: List[int], transposed: bool, output_padding: List[int], groups: int, benchmark: bool, deterministic: bool, cudnn_enabled: bool, allow_tf32: bool) -> List[int]:
    return aten〇convolution〡shape(input, weight, bias, stride, padding, dilation, transposed, output_padding, groups)
@@ -3538,6 +3544,10 @@ def aten〇conv3d〡dtype(input_rank_dtype: Tuple[int, int], weight_rank_dtype:
    input_rank, input_dtype = input_rank_dtype
    return input_dtype
+def aten〇conv_transpose1d〡dtype(input_rank_dtype: Tuple[int, int], weight_rank_dtype: Tuple[int, int], bias_rank_dtype: Optional[Tuple[int, int]] = None, stride: List[int] = (1,), padding: List[int] = (0,), output_padding: List[int] = (0,), groups: int = 1, dilation: List[int] = (1,)) -> int:
+    input_rank, input_dtype = input_rank_dtype
+    return input_dtype
@check_dtype_function(
    _check_tensors_with_the_same_dtype(tensor_shapes=[(1, 1, 1, 1), (1, 1, 1, 1)]) +
    [Invocation(TensorOfShape(1, 1, 1, 1, dtype=torch.bool), TensorOfShape(1, 1, 1, 1, dtype=torch.float32)),
@@ -3549,6 +3559,10 @@ def aten〇conv_transpose2d〇input〡dtype(input_rank_dtype: Tuple[int, int], w
    input_rank, input_dtype = input_rank_dtype
    return input_dtype
+def aten〇conv_transpose3d〇input〡dtype(input_rank_dtype: Tuple[int, int], weight_rank_dtype: Tuple[int, int], bias_rank_dtype: Optional[Tuple[int, int]] = None, stride: List[int] = (1, 1, 1,), padding: List[int] = (0, 0, 0,), output_padding: List[int] = (0, 0, 0,), groups: int = 1, dilation: List[int] = (1, 1, 1,)) -> int:
+    input_rank, input_dtype = input_rank_dtype
+    return input_dtype
convolution_kwargs = {
    "stride" : [1, 1], "padding" : [0, 0], "dilation" : [1, 1], "transposed" : False, "output_padding" : [0, 0], "groups" : 1}
@check_dtype_function(
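
As with the other convolution variants, the new dtype functions simply propagate the input tensor's dtype, which matches eager behavior (illustration only, not part of the patch):

import torch

x = torch.rand(5, 2, 6, dtype=torch.float64)
w = torch.rand(2, 5, 2, dtype=torch.float64)
y = torch.ops.aten.conv_transpose1d(x, w)
assert y.dtype == torch.float64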


@@ -760,6 +760,66 @@ def ConvolutionModule2DTransposeNonUnitOutputPadding_basic(module, tu: TestUtils
    module.forward(tu.rand(1, 2, 4, 4), tu.rand(2, 2, 3, 3))
+class Conv_Transpose1dModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([-1, -1, -1], torch.float32, True),
+            ([-1, -1, -1], torch.float32, True),
+        ]
+    )
+    def forward(self, inputVec, weight):
+        return torch.ops.aten.conv_transpose1d(
+            inputVec,
+            weight,
+            bias=None,
+            stride=[2],
+            padding=[1],
+            dilation=[1],
+            output_padding=[0],
+            groups=1,
+        )
+
+
+@register_test_case(module_factory=lambda: Conv_Transpose1dModule())
+def Conv_Transpose1dModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(5, 2, 6), tu.rand(2, 5, 2))
+
+
+class Conv_Transpose1dStaticModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([5, 2, 6], torch.float32, True),
+            ([2, 5, 2], torch.float32, True),
+        ]
+    )
+    def forward(self, inputVec, weight):
+        return torch.ops.aten.conv_transpose1d(
+            inputVec,
+            weight,
+            bias=None,
+            stride=[2],
+            padding=[1],
+            dilation=[1],
+            output_padding=[0],
+            groups=1,
+        )
+
+
+@register_test_case(module_factory=lambda: Conv_Transpose1dStaticModule())
+def Conv_Transpose1dStaticModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(5, 2, 6), tu.rand(2, 5, 2))
class Conv_Transpose2dModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
@@ -790,6 +850,96 @@ def Conv_Transpose2dModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(5, 2, 5, 6), tu.rand(2, 5, 2, 2))
+class Conv_Transpose2dStaticModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([5, 2, 5, 6], torch.float32, True),
+            ([2, 5, 2, 2], torch.float32, True),
+        ]
+    )
+    def forward(self, inputVec, weight):
+        return torch.ops.aten.conv_transpose2d(
+            inputVec,
+            weight,
+            bias=None,
+            stride=[2, 2],
+            padding=[1, 1],
+            dilation=[1, 1],
+            output_padding=[0, 0],
+            groups=1,
+        )
+
+
+@register_test_case(module_factory=lambda: Conv_Transpose2dStaticModule())
+def Conv_Transpose2dStaticModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(5, 2, 5, 6), tu.rand(2, 5, 2, 2))
+
+
+class Conv_Transpose3dModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([-1, -1, -1, -1, -1], torch.float32, True),
+            ([-1, -1, -1, -1, -1], torch.float32, True),
+        ]
+    )
+    def forward(self, inputVec, weight):
+        return torch.ops.aten.conv_transpose3d(
+            inputVec,
+            weight,
+            bias=None,
+            stride=[2, 2, 2],
+            padding=[1, 1, 1],
+            dilation=[1, 1, 1],
+            output_padding=[0, 0, 0],
+            groups=1,
+        )
+
+
+@register_test_case(module_factory=lambda: Conv_Transpose3dModule())
+def Conv_Transpose3dModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(5, 2, 5, 6, 7), tu.rand(2, 5, 2, 2, 2))
+
+
+class Conv_Transpose3dStaticModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([5, 2, 5, 6, 7], torch.float32, True),
+            ([2, 5, 2, 2, 2], torch.float32, True),
+        ]
+    )
+    def forward(self, inputVec, weight):
+        return torch.ops.aten.conv_transpose3d(
+            inputVec,
+            weight,
+            bias=None,
+            stride=[2, 2, 2],
+            padding=[1, 1, 1],
+            dilation=[1, 1, 1],
+            output_padding=[0, 0, 0],
+            groups=1,
+        )
+
+
+@register_test_case(module_factory=lambda: Conv_Transpose3dStaticModule())
+def Conv_Transpose3dStaticModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(5, 2, 5, 6, 7), tu.rand(2, 5, 2, 2, 2))
class UpSampleNearest2d(torch.nn.Module):
    def __init__(self):
        super().__init__()
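
For reference, the expected output shapes of the new test cases follow from the transposed-convolution formula quoted earlier; a quick eager-mode cross-check (sketch only; the e2e harness compares full tensor values against eager PyTorch):

import torch

y1 = torch.ops.aten.conv_transpose1d(
    torch.rand(5, 2, 6), torch.rand(2, 5, 2),
    None, [2], [1], [0], 1, [1])
y3 = torch.ops.aten.conv_transpose3d(
    torch.rand(5, 2, 5, 6, 7), torch.rand(2, 5, 2, 2, 2),
    None, [2, 2, 2], [1, 1, 1], [0, 0, 0], 1, [1, 1, 1])
assert y1.shape == (5, 5, 10)           # (6-1)*2 - 2 + 1 + 1
assert y3.shape == (5, 5, 8, 10, 12)    # per spatial dim: 8, 10, 12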