mirror of https://github.com/llvm/torch-mlir
[Torch]Support conv_transpose1d and conv_transpose3d (#3286)
1. Support conv_transpose1d and conv_transpose3d 2. Fix bugs of convertTransposedConv func in lib/Conversion/TorchToStablehlo/Linear.cpppull/3413/head
parent
617b00b983
commit
23b53050de
|
@ -591,25 +591,32 @@ public:
|
|||
auto weightShape = weightTy.getShape();
|
||||
|
||||
auto nDims = inputTy.getRank();
|
||||
auto weightDims = weightTy.getRank();
|
||||
auto kernelDims = weightDims - 2;
|
||||
|
||||
auto nSpatialDims = nDims - 2;
|
||||
auto convOutTy = outType;
|
||||
|
||||
// Transpose weight
|
||||
SmallVector<int64_t> perm(nDims);
|
||||
SmallVector<int64_t> transposeShape(nDims);
|
||||
for (int i = 0; i < nDims; i++) {
|
||||
if (i < 2)
|
||||
perm[i] = nDims - 2 + i;
|
||||
// 1d: kernelDims = 1, [0, 1, 2] => [2, 1, 0]
|
||||
// 2d: kernelDims = 2, [0, 1, 2, 3] => [2, 3, 1, 0]
|
||||
// 3d: kernelDims = 3, [0, 1, 2, 3, 4] => [2, 3, 4, 1, 0]
|
||||
for (int i = 0; i < weightDims; i++) {
|
||||
if (i < kernelDims)
|
||||
perm[i] = 2 + i;
|
||||
else
|
||||
perm[i] = nDims - i - 1;
|
||||
perm[i] = kernelDims + 1 - i;
|
||||
transposeShape[i] = weightShape[perm[i]];
|
||||
}
|
||||
auto reverseDim = llvm::to_vector<4>(llvm::seq<int64_t>(0, kernelDims));
|
||||
auto transposeTy =
|
||||
RankedTensorType::get(transposeShape, weightTy.getElementType());
|
||||
auto transposeOp = rewriter.create<stablehlo::TransposeOp>(
|
||||
op->getLoc(), transposeTy, weight, perm);
|
||||
auto reverseOp = rewriter.create<stablehlo::ReverseOp>(
|
||||
op->getLoc(), transposeOp, ArrayRef<int64_t>{0, 1});
|
||||
op->getLoc(), transposeOp, reverseDim);
|
||||
|
||||
// Prepare for transposed convolution
|
||||
SmallVector<int64_t> stablehloStrideVec(nSpatialDims, 1);
|
||||
|
|
|
@ -9110,6 +9110,16 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
|
|||
" %1 = call @__torch__.torch.jit._shape_functions.conv_forwards(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %false, %0, %int1) : (!torch.list<int>, !torch.list<int>, !torch.optional<list<int>>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int) -> !torch.list<int>\n"
|
||||
" return %1 : !torch.list<int>\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_shape_fn.aten.conv_transpose1d\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.optional<list<int>>, %arg3: !torch.list<int>, %arg4: !torch.list<int>, %arg5: !torch.list<int>, %arg6: !torch.int, %arg7: !torch.list<int>) -> !torch.list<int> {\n"
|
||||
" %true = torch.constant.bool true\n"
|
||||
" %0 = call @\"__torch_mlir_shape_fn.aten.convolution\"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg7, %true, %arg5, %arg6) : (!torch.list<int>, !torch.list<int>, !torch.optional<list<int>>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int) -> !torch.list<int>\n"
|
||||
" return %0 : !torch.list<int>\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_shape_fn.aten.conv_transpose3d.input\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.optional<list<int>>, %arg3: !torch.list<int>, %arg4: !torch.list<int>, %arg5: !torch.list<int>, %arg6: !torch.int, %arg7: !torch.list<int>) -> !torch.list<int> {\n"
|
||||
" %true = torch.constant.bool true\n"
|
||||
" %0 = call @\"__torch_mlir_shape_fn.aten.convolution\"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg7, %true, %arg5, %arg6) : (!torch.list<int>, !torch.list<int>, !torch.optional<list<int>>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int) -> !torch.list<int>\n"
|
||||
" return %0 : !torch.list<int>\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_shape_fn.aten._convolution\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.optional<list<int>>, %arg3: !torch.list<int>, %arg4: !torch.list<int>, %arg5: !torch.list<int>, %arg6: !torch.bool, %arg7: !torch.list<int>, %arg8: !torch.int, %arg9: !torch.bool, %arg10: !torch.bool, %arg11: !torch.bool, %arg12: !torch.bool) -> !torch.list<int> {\n"
|
||||
" %0 = call @\"__torch_mlir_shape_fn.aten.convolution\"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8) : (!torch.list<int>, !torch.list<int>, !torch.optional<list<int>>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int) -> !torch.list<int>\n"
|
||||
" return %0 : !torch.list<int>\n"
|
||||
|
@ -11797,10 +11807,18 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
|
|||
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
|
||||
" return %0#1 : !torch.int\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_dtype_fn.aten.conv_transpose1d\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.optional<tuple<int, int>>, %arg3: !torch.list<int>, %arg4: !torch.list<int>, %arg5: !torch.list<int>, %arg6: !torch.int, %arg7: !torch.list<int>) -> !torch.int {\n"
|
||||
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
|
||||
" return %0#1 : !torch.int\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_dtype_fn.aten.conv_transpose2d.input\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.optional<tuple<int, int>>, %arg3: !torch.list<int>, %arg4: !torch.list<int>, %arg5: !torch.list<int>, %arg6: !torch.int, %arg7: !torch.list<int>) -> !torch.int {\n"
|
||||
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
|
||||
" return %0#1 : !torch.int\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_dtype_fn.aten.conv_transpose3d.input\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.optional<tuple<int, int>>, %arg3: !torch.list<int>, %arg4: !torch.list<int>, %arg5: !torch.list<int>, %arg6: !torch.int, %arg7: !torch.list<int>) -> !torch.int {\n"
|
||||
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
|
||||
" return %0#1 : !torch.int\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_dtype_fn.aten.convolution\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.optional<tuple<int, int>>, %arg3: !torch.list<int>, %arg4: !torch.list<int>, %arg5: !torch.list<int>, %arg6: !torch.bool, %arg7: !torch.list<int>, %arg8: !torch.int) -> !torch.int {\n"
|
||||
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
|
||||
" return %0#1 : !torch.int\n"
|
||||
|
|
|
@ -3633,6 +3633,25 @@ public:
|
|||
};
|
||||
} // namespace
|
||||
|
||||
// Decompose aten.conv_transpose1d to aten.convolution
|
||||
namespace {
|
||||
class DecomposeAtenConvTranspose1dOp
|
||||
: public OpRewritePattern<AtenConvTranspose1dOp> {
|
||||
public:
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
LogicalResult matchAndRewrite(AtenConvTranspose1dOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
|
||||
Value cstTrue = rewriter.create<Torch::ConstantBoolOp>(op.getLoc(), true);
|
||||
rewriter.replaceOpWithNewOp<AtenConvolutionOp>(
|
||||
op, op->getResultTypes(), op.getInput(), op.getWeight(), op.getBias(),
|
||||
op.getStride(), op.getPadding(), op.getDilation(),
|
||||
/*transposed=*/cstTrue, op.getOutputPadding(), op.getGroups());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
// Decompose aten.conv_transpose2d to aten.convolution
|
||||
namespace {
|
||||
class DecomposeAtenConvTranspose2dOp
|
||||
|
@ -3652,6 +3671,25 @@ public:
|
|||
};
|
||||
} // namespace
|
||||
|
||||
// Decompose aten.conv_transpose3d to aten.convolution
|
||||
namespace {
|
||||
class DecomposeAtenConvTranspose3dOp
|
||||
: public OpRewritePattern<AtenConvTranspose3dInputOp> {
|
||||
public:
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
LogicalResult matchAndRewrite(AtenConvTranspose3dInputOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
|
||||
Value cstTrue = rewriter.create<Torch::ConstantBoolOp>(op.getLoc(), true);
|
||||
rewriter.replaceOpWithNewOp<AtenConvolutionOp>(
|
||||
op, op->getResultTypes(), op.getInput(), op.getWeight(), op.getBias(),
|
||||
op.getStride(), op.getPadding(), op.getDilation(),
|
||||
/*transposed=*/cstTrue, op.getOutputPadding(), op.getGroups());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
// The convolution backward op is decomposed as follows:
|
||||
// inputH, inputW = input.shape[2:]
|
||||
// output_padding_ = [
|
||||
|
@ -7963,7 +8001,9 @@ public:
|
|||
DecomposeAten_ConvolutionLikeOp<Aten_ConvolutionDeprecatedOp>>(
|
||||
patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenConvolutionBackwardOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenConvTranspose1dOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenConvTranspose2dOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenConvTranspose3dOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenArangeOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenArangeStartOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposePrimsIotaOp>(patterns);
|
||||
|
|
|
@ -428,7 +428,9 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
|
|||
target.addIllegalOp<AtenConv1dOp>();
|
||||
target.addIllegalOp<AtenConv2dOp>();
|
||||
target.addIllegalOp<AtenConv3dOp>();
|
||||
target.addIllegalOp<AtenConvTranspose1dOp>();
|
||||
target.addIllegalOp<AtenConvTranspose2dInputOp>();
|
||||
target.addIllegalOp<AtenConvTranspose3dInputOp>();
|
||||
target.addIllegalOp<AtenArangeOp>();
|
||||
target.addIllegalOp<AtenArangeStartOp>();
|
||||
target.addIllegalOp<AtenLinspaceOp>();
|
||||
|
|
|
@ -911,6 +911,9 @@ STABLEHLO_PASS_SET = {
|
|||
"Convolution2DStaticModule_basic",
|
||||
"ConvolutionBackwardModule2DStatic_basic",
|
||||
"ConvolutionModule2DTransposeStridedStatic_basic",
|
||||
"Conv_Transpose1dStaticModule_basic",
|
||||
"Conv_Transpose2dStaticModule_basic",
|
||||
"Conv_Transpose3dStaticModule_basic",
|
||||
"ConstantPad2dStaticModule_basic",
|
||||
"ConstantPadNdModule_basic",
|
||||
"ConstantPadNdPartialStaticModule_basic",
|
||||
|
@ -2662,6 +2665,8 @@ ONNX_XFAIL_SET = {
|
|||
"PrimsIotaModule_basic",
|
||||
# Failure - unknown
|
||||
"BernoulliModule_basic",
|
||||
"Conv_Transpose1dModule_basic",
|
||||
"Conv_Transpose3dModule_basic",
|
||||
"Conv2dWithPaddingDilationStrideStaticModule_depthwise_multiplier",
|
||||
"CopyWithDifferentDTypesAndSizesModule_basic",
|
||||
"CopyWithDifferentDTypesModule_basic",
|
||||
|
|
|
@ -1548,6 +1548,12 @@ def aten〇convolution〡shape(input: List[int], weight: List[int], bias: Option
|
|||
def aten〇conv1d〡shape(input: List[int], weight: List[int], bias: Optional[List[int]] = None, stride: List[int] = (1,), padding: List[int] = (0,), dilation: List[int] = (1,), groups: int = 1) -> List[int]:
|
||||
return upstream_shape_functions.conv_forwards(input, weight, bias, stride, padding, dilation, transposed=False, output_padding=[], groups=1)
|
||||
|
||||
def aten〇conv_transpose1d〡shape(input: List[int], weight: List[int], bias: Optional[List[int]] = None, stride: List[int] = (1,), padding: List[int] = (0,), output_padding: List[int] = (0,), groups: int = 1, dilation: List[int] = (1,)) -> List[int]:
|
||||
return aten〇convolution〡shape(input, weight, bias, stride, padding, dilation, True, output_padding, groups)
|
||||
|
||||
def aten〇conv_transpose3d〇input〡shape(input: List[int], weight: List[int], bias: Optional[List[int]] = None, stride: List[int] = (1, 1, 1,), padding: List[int] = (0, 0, 0,), output_padding: List[int] = (0, 0, 0,), groups: int = 1, dilation: List[int] = (1, 1, 1,)) -> List[int]:
|
||||
return aten〇convolution〡shape(input, weight, bias, stride, padding, dilation, True, output_padding, groups)
|
||||
|
||||
def aten〇_convolution〡shape(input: List[int], weight: List[int], bias: Optional[List[int]], stride: List[int], padding: List[int], dilation: List[int], transposed: bool, output_padding: List[int], groups: int, benchmark: bool, deterministic: bool, cudnn_enabled: bool, allow_tf32: bool) -> List[int]:
|
||||
return aten〇convolution〡shape(input, weight, bias, stride, padding, dilation, transposed, output_padding, groups)
|
||||
|
||||
|
@ -3538,6 +3544,10 @@ def aten〇conv3d〡dtype(input_rank_dtype: Tuple[int, int], weight_rank_dtype:
|
|||
input_rank, input_dtype = input_rank_dtype
|
||||
return input_dtype
|
||||
|
||||
def aten〇conv_transpose1d〡dtype(input_rank_dtype: Tuple[int, int], weight_rank_dtype: Tuple[int, int], bias_rank_dtype: Optional[Tuple[int, int]] = None, stride: List[int] = (1,), padding: List[int] = (0,), output_padding: List[int] = (0,), groups: int = 1, dilation: List[int] = (1,)) -> int:
|
||||
input_rank, input_dtype = input_rank_dtype
|
||||
return input_dtype
|
||||
|
||||
@check_dtype_function(
|
||||
_check_tensors_with_the_same_dtype(tensor_shapes=[(1, 1, 1, 1), (1, 1, 1, 1)]) +
|
||||
[Invocation(TensorOfShape(1, 1, 1, 1, dtype=torch.bool), TensorOfShape(1, 1, 1, 1, dtype=torch.float32)),
|
||||
|
@ -3549,6 +3559,10 @@ def aten〇conv_transpose2d〇input〡dtype(input_rank_dtype: Tuple[int, int], w
|
|||
input_rank, input_dtype = input_rank_dtype
|
||||
return input_dtype
|
||||
|
||||
def aten〇conv_transpose3d〇input〡dtype(input_rank_dtype: Tuple[int, int], weight_rank_dtype: Tuple[int, int], bias_rank_dtype: Optional[Tuple[int, int]] = None, stride: List[int] = (1, 1, 1,), padding: List[int] = (0, 0, 0,), output_padding: List[int] = (0, 0, 0,), groups: int = 1, dilation: List[int] = (1, 1, 1,)) -> int:
|
||||
input_rank, input_dtype = input_rank_dtype
|
||||
return input_dtype
|
||||
|
||||
convolution_kwargs = {
|
||||
"stride" : [1, 1], "padding" : [0, 0], "dilation" : [1, 1], "transposed" : False, "output_padding" : [0, 0], "groups" : 1}
|
||||
@check_dtype_function(
|
||||
|
|
|
@ -760,6 +760,66 @@ def ConvolutionModule2DTransposeNonUnitOutputPadding_basic(module, tu: TestUtils
|
|||
module.forward(tu.rand(1, 2, 4, 4), tu.rand(2, 2, 3, 3))
|
||||
|
||||
|
||||
class Conv_Transpose1dModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([-1, -1, -1], torch.float32, True),
|
||||
([-1, -1, -1], torch.float32, True),
|
||||
]
|
||||
)
|
||||
def forward(self, inputVec, weight):
|
||||
return torch.ops.aten.conv_transpose1d(
|
||||
inputVec,
|
||||
weight,
|
||||
bias=None,
|
||||
stride=[2],
|
||||
padding=[1],
|
||||
dilation=[1],
|
||||
output_padding=[0],
|
||||
groups=1,
|
||||
)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: Conv_Transpose1dModule())
|
||||
def Conv_Transpose1dModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(5, 2, 6), tu.rand(2, 5, 2))
|
||||
|
||||
|
||||
class Conv_Transpose1dStaticModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([5, 2, 6], torch.float32, True),
|
||||
([2, 5, 2], torch.float32, True),
|
||||
]
|
||||
)
|
||||
def forward(self, inputVec, weight):
|
||||
return torch.ops.aten.conv_transpose1d(
|
||||
inputVec,
|
||||
weight,
|
||||
bias=None,
|
||||
stride=[2],
|
||||
padding=[1],
|
||||
dilation=[1],
|
||||
output_padding=[0],
|
||||
groups=1,
|
||||
)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: Conv_Transpose1dStaticModule())
|
||||
def Conv_Transpose1dStaticModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(5, 2, 6), tu.rand(2, 5, 2))
|
||||
|
||||
|
||||
class Conv_Transpose2dModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
@ -790,6 +850,96 @@ def Conv_Transpose2dModule_basic(module, tu: TestUtils):
|
|||
module.forward(tu.rand(5, 2, 5, 6), tu.rand(2, 5, 2, 2))
|
||||
|
||||
|
||||
class Conv_Transpose2dStaticModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([5, 2, 5, 6], torch.float32, True),
|
||||
([2, 5, 2, 2], torch.float32, True),
|
||||
]
|
||||
)
|
||||
def forward(self, inputVec, weight):
|
||||
return torch.ops.aten.conv_transpose2d(
|
||||
inputVec,
|
||||
weight,
|
||||
bias=None,
|
||||
stride=[2, 2],
|
||||
padding=[1, 1],
|
||||
dilation=[1, 1],
|
||||
output_padding=[0, 0],
|
||||
groups=1,
|
||||
)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: Conv_Transpose2dStaticModule())
|
||||
def Conv_Transpose2dStaticModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(5, 2, 5, 6), tu.rand(2, 5, 2, 2))
|
||||
|
||||
|
||||
class Conv_Transpose3dModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([-1, -1, -1, -1, -1], torch.float32, True),
|
||||
([-1, -1, -1, -1, -1], torch.float32, True),
|
||||
]
|
||||
)
|
||||
def forward(self, inputVec, weight):
|
||||
return torch.ops.aten.conv_transpose3d(
|
||||
inputVec,
|
||||
weight,
|
||||
bias=None,
|
||||
stride=[2, 2, 2],
|
||||
padding=[1, 1, 1],
|
||||
dilation=[1, 1, 1],
|
||||
output_padding=[0, 0, 0],
|
||||
groups=1,
|
||||
)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: Conv_Transpose3dModule())
|
||||
def Conv_Transpose3dModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(5, 2, 5, 6, 7), tu.rand(2, 5, 2, 2, 2))
|
||||
|
||||
|
||||
class Conv_Transpose3dStaticModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([5, 2, 5, 6, 7], torch.float32, True),
|
||||
([2, 5, 2, 2, 2], torch.float32, True),
|
||||
]
|
||||
)
|
||||
def forward(self, inputVec, weight):
|
||||
return torch.ops.aten.conv_transpose3d(
|
||||
inputVec,
|
||||
weight,
|
||||
bias=None,
|
||||
stride=[2, 2, 2],
|
||||
padding=[1, 1, 1],
|
||||
dilation=[1, 1, 1],
|
||||
output_padding=[0, 0, 0],
|
||||
groups=1,
|
||||
)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: Conv_Transpose3dStaticModule())
|
||||
def Conv_Transpose3dStaticModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(5, 2, 5, 6, 7), tu.rand(2, 5, 2, 2, 2))
|
||||
|
||||
|
||||
class UpSampleNearest2d(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
|
Loading…
Reference in New Issue