From 351f15424e67302bb0b173c367de4cd71c60ecb0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=AD=A6=E5=AE=B6=E4=BC=9F?= <73166454+Vremold@users.noreply.github.com> Date: Tue, 9 Aug 2022 09:50:07 +0800 Subject: [PATCH] [MHLO] Add transposed convolution conversion pattern (#1171) Co-authored-by: Bairen Yi Co-authored-by: Jiawei Wu Co-authored-by: Tianyou Guo Co-authored-by: Xu Yan Co-authored-by: Ziheng Jiang --- lib/Conversion/TorchToMhlo/Linear.cpp | 331 ++++++++++++++++++------ test/Conversion/TorchToMhlo/linear.mlir | 174 ++++++++++++- 2 files changed, 420 insertions(+), 85 deletions(-) diff --git a/lib/Conversion/TorchToMhlo/Linear.cpp b/lib/Conversion/TorchToMhlo/Linear.cpp index 09a3bc934..42dbd1798 100644 --- a/lib/Conversion/TorchToMhlo/Linear.cpp +++ b/lib/Conversion/TorchToMhlo/Linear.cpp @@ -379,84 +379,180 @@ public: } // namespace -// AtenConvolutionOp namespace { -class ConvertAtenConvlutionOp : public OpConversionPattern { +class ConvertAtenConvolutionOp : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; using OpAdaptor = typename AtenConvolutionOp::Adaptor; - LogicalResult - matchAndRewrite(AtenConvolutionOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Value input = adaptor.input(); - Value weight = adaptor.weight(); + Value reshapeConvWeight(PatternRewriter &rewriter, Operation *op, + Value weight, int64_t groups) const { + auto weightTy = weight.getType().cast(); + auto weightElemTy = weightTy.getElementType(); + auto rank = weightTy.getRank(); + SmallVector weightShapeVec = + *mhlo::getDimSizesOfTensor(rewriter, op, weight); + auto weightShape = weightTy.getShape(); + SmallVector weightShapeInt(rank); + std::copy(weightShape.begin(), weightShape.end(), weightShapeInt.begin()); - // The input shape is [N, C, H, W] - auto inputTy = input.getType().template cast(); - // The weight shape is [OC, (IC // groups), KH, KW] - // If tranposed is set to true, the weight shape changes to [IC, (OC // - // groups), KH, KW] - auto weightTy = weight.getType().template cast(); - auto outTy = getTypeConverter() - ->convertType(op.getType()) - .template cast(); + // 1. [IC, OC, H, W, ...] => [G, IC//G, OC, H, W, ...] + Value GValue = rewriter.create( + op->getLoc(), rewriter.getI64IntegerAttr(groups)); + Value ICDivGValue = rewriter.create( + op->getLoc(), weightShapeVec[0], GValue); + Value OCMulGValue = rewriter.create( + op->getLoc(), weightShapeVec[1], GValue); + weightShapeVec[0] = ICDivGValue; + weightShapeVec.insert(weightShapeVec.begin(), GValue); - if (!inputTy || !weightTy || !outTy) { - return op.emitError("input, weight and output must be ranked tensors"); + if (weightShapeInt[0] == ShapedType::kDynamicSize) { + weightShapeInt.insert(weightShapeInt.begin(), groups); + } else { + weightShapeInt[0] /= groups; + weightShapeInt.insert(weightShapeInt.begin(), groups); + } + Value weightShapeTensor = rewriter.create( + op->getLoc(), weightShapeVec); + weight = rewriter.create( + op->getLoc(), RankedTensorType::get(weightShapeInt, weightElemTy), + weight, weightShapeTensor); + + // 2. [G, IC//G, OC, H, W, ...] => [IC//G, G, OC, H, W, ...] + std::vector transposeDims(rank + 1); + for (int64_t i = 0; i <= rank; i++) + transposeDims[i] = i; + std::swap(transposeDims[1], transposeDims[0]); + weight = rewriter.create( + op->getLoc(), weight, rewriter.getI64TensorAttr(transposeDims)); + + // 3. [IC//G, G, OC, H, W, ...] => [IC//G, G*OC, H, W, ...] 
+ weightShapeInt.erase(weightShapeInt.begin()); + if (weightShapeInt[1] != ShapedType::kDynamicSize) { + weightShapeInt[1] *= groups; + } + weightShapeVec.erase(weightShapeVec.begin()); + weightShapeVec[1] = OCMulGValue; + weightShapeTensor = rewriter.create( + op->getLoc(), weightShapeVec); + weight = rewriter.create( + op->getLoc(), RankedTensorType::get(weightShapeInt, weightElemTy), + weight, weightShapeTensor); + return weight; + } + + Value convertTransposedConv(AtenConvolutionOp op, + ConversionPatternRewriter &rewriter, + RankedTensorType outType, Value input, + Value weight, ArrayRef stride, + ArrayRef padding, + ArrayRef dilation, + ArrayRef outputPadding, int64_t groups, + bool needHandleOutputPadding) const { + auto inputTy = input.getType().cast(); + auto weightTy = weight.getType().cast(); + auto weightShape = weightTy.getShape(); + + auto nDims = inputTy.getRank(); + auto nSpatialDims = nDims - 2; + auto convOutTy = outType; + + if (needHandleOutputPadding) { + SmallVector outShape(nDims); + auto finalOutShape = outType.getShape(); + std::copy(finalOutShape.begin(), finalOutShape.end(), outShape.begin()); + for (int i = 2; i < nDims; ++i) { + if (finalOutShape[i] == ShapedType::kDynamicSize) + continue; + outShape[i] = finalOutShape[i] - outputPadding[i - 2]; + } + convOutTy = RankedTensorType::get(outShape, outType.getElementType()); } - if (inputTy.getRank() < 3) - return op.emitError("only input with at least 3 dims valid"); + // Prepare for transposed convolution + SmallVector mhloStrideVec(nSpatialDims, 1); + DenseIntElementsAttr mhloStride = rewriter.getI64TensorAttr(mhloStrideVec); + SmallVector mhloPaddingVec(nSpatialDims * 2, 0); + for (int i = 0; i < nSpatialDims; ++i) { + int64_t padInt = dilation[i] * (weightShape[i + 2] - 1) - padding[i]; + mhloPaddingVec[i * 2] = padInt; + mhloPaddingVec[i * 2 + 1] = padInt; + } + DenseIntElementsAttr mhloPadding = DenseIntElementsAttr::get( + RankedTensorType::get({nSpatialDims, 2}, rewriter.getI64Type()), + mhloPaddingVec); + SmallVector mhloLhsDilationVec(nSpatialDims); + std::copy(stride.begin(), stride.end(), mhloLhsDilationVec.begin()); + DenseIntElementsAttr mhloLhsDilation = + rewriter.getI64TensorAttr(mhloLhsDilationVec); + SmallVector mhloRhsDilationVec(nSpatialDims); + std::copy(dilation.begin(), dilation.end(), mhloRhsDilationVec.begin()); + DenseIntElementsAttr mhloRhsDilation = + rewriter.getI64TensorAttr(mhloRhsDilationVec); - SmallVector stride; - if (!matchPattern(op.stride(), m_TorchConstantIntList(stride))) { - return rewriter.notifyMatchFailure(op, - "non-const stride list unsupported"); + DenseElementsAttr windowReversal; + ArrayAttr precisionConfig; + + SmallVector spatialDims; + for (int i = 0; i < nSpatialDims; ++i) { + spatialDims.push_back(i + 2); + } + mhlo::ConvDimensionNumbersAttr dimensionNumbers = + mhlo::ConvDimensionNumbersAttr::get( + /*context=*/rewriter.getContext(), /*inputBatchDimension=*/0, + /*inputFeatureDimension=*/1, + /*inputSpatialDimensions=*/spatialDims, + /*kernelInputFeatureDimension=*/0, + /*kernelOutputFeatureDimension=*/1, + /*kernelSpatialDimensions=*/spatialDims, + /*outputBatchDimension=*/0, /*outputFeatureDimension=*/1, + /*outputSpatialDimensions=*/spatialDims); + + // Reverse and transpose weight + weight = rewriter.create( + op->getLoc(), weight, rewriter.getI64TensorAttr(spatialDims)); + if (groups != 1) { + weight = reshapeConvWeight(rewriter, op, weight, groups); } - SmallVector padding; - if (!matchPattern(op.padding(), m_TorchConstantIntList(padding))) { - return 
rewriter.notifyMatchFailure(op, - "non-const padding list unsupported"); - } + // Create transposed convolution + auto transposedConvOp = rewriter.create( + op->getLoc(), convOutTy, input, weight, mhloStride, mhloPadding, + mhloLhsDilation, mhloRhsDilation, windowReversal, dimensionNumbers, + static_cast(groups), 1, precisionConfig); - SmallVector dilation; - if (!matchPattern(op.dilation(), m_TorchConstantIntList(dilation))) { - return rewriter.notifyMatchFailure(op, - "non-const dilation list unsupported"); - } - SmallVector outputPadding; - if (!matchPattern(op.output_padding(), - m_TorchConstantIntList(outputPadding))) { - return rewriter.notifyMatchFailure( - op, "non-const output_padding list unsupported"); - } - // Just ignore the outputPadding attribute - for (int64_t item : outputPadding) { - if (item != 0) - return rewriter.notifyMatchFailure( - op, "only zero output_padding list supported"); + // Handle output padding + if (!needHandleOutputPadding) { + return transposedConvOp.getResult(); } + SmallVector edgePaddingLowVec(nDims, 0); + SmallVector edgePaddingHighVec(nDims, 0); + SmallVector interiorPaddingVec(nDims, 0); + std::copy(outputPadding.begin(), outputPadding.end(), + edgePaddingHighVec.begin() + 2); + Value paddingValue = + mhlo::getConstTensor(rewriter, op, {0.0}, {}).getValue(); + paddingValue = mhlo::promoteType(rewriter, paddingValue, inputTy); + mlir::DenseIntElementsAttr edgePaddingLow = + rewriter.getI64VectorAttr(edgePaddingLowVec); + mlir::DenseIntElementsAttr edgePaddingHigh = + rewriter.getI64VectorAttr(edgePaddingHighVec); + mlir::DenseIntElementsAttr interiorPadding = + rewriter.getI64VectorAttr(interiorPaddingVec); - int64_t groups; - if (!matchPattern(op.groups(), m_TorchConstantInt(&groups))) { - return rewriter.notifyMatchFailure(op, "non-int groups unsupported"); - } + auto paddedOutput = rewriter.create( + op->getLoc(), outType, transposedConvOp, paddingValue, edgePaddingLow, + edgePaddingHigh, interiorPadding); - bool transposed; - if (!matchPattern(op.transposed(), m_TorchConstantBool(&transposed))) { - return rewriter.notifyMatchFailure(op, "non-bool transposed unsupported"); - } - if (transposed) { - return rewriter.notifyMatchFailure( - op, "only param tranposed of value 'false' supported!"); - } + return paddedOutput.getResult(); + } - assert(padding.size() == dilation.size() && - padding.size() == stride.size() && - padding.size() == static_cast(inputTy.getRank()) - 2); - int64_t nSpatialDims = padding.size(); + Value convertNormalConv(AtenConvolutionOp op, + ConversionPatternRewriter &rewriter, + RankedTensorType outType, Value input, Value weight, + ArrayRef stride, ArrayRef padding, + ArrayRef dilation, int64_t groups) const { + int64_t nDims = outType.getRank(); // Get mhlo::ConvolutionOp attributes DenseIntElementsAttr mhloWindowStride = DenseIntElementsAttr::get( @@ -468,20 +564,17 @@ public: mhloPaddingVec.emplace_back(padding[i]); mhloPaddingVec.emplace_back(padding[i]); } - DenseIntElementsAttr mhloPadding = DenseIntElementsAttr::get( RankedTensorType::get( {static_cast(padding.size()), static_cast(2)}, rewriter.getI64Type()), mhloPaddingVec); - DenseIntElementsAttr mhloRhsDilation = DenseIntElementsAttr::get( RankedTensorType::get({static_cast(dilation.size())}, rewriter.getI64Type()), dilation); - SmallVector spatialDimensions; - for (int64_t i = 2; i < inputTy.getRank(); i++) { + for (int64_t i = 2; i < nDims; i++) { spatialDimensions.emplace_back(i); } mhlo::ConvDimensionNumbersAttr dimensionNumbers = @@ -495,25 +588,113 @@ 
public: /*outputBatchDimension=*/0, /*outputFeatureDimension=*/1, /*outputSpatialDimensions=*/spatialDimensions); - IntegerAttr featureGroupCount = - IntegerAttr::get(rewriter.getI64Type(), groups); - IntegerAttr batchGroupCount = IntegerAttr::get(rewriter.getI64Type(), 1); - // mhlo::ConvolutionOp's optional attributes, leave them as default DenseIntElementsAttr mhloLhsDilation; DenseElementsAttr windowReversal; ArrayAttr precisionConfig; auto mhloConvOp = rewriter.create( - op->getLoc(), outTy, input, weight, mhloWindowStride, mhloPadding, + op->getLoc(), outType, input, weight, mhloWindowStride, mhloPadding, mhloLhsDilation, mhloRhsDilation, windowReversal, dimensionNumbers, - featureGroupCount, batchGroupCount, precisionConfig); + static_cast(groups), 1, precisionConfig); + + return mhloConvOp.getResult(); + } + + LogicalResult + matchAndRewrite(AtenConvolutionOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Value input = adaptor.input(); + Value weight = adaptor.weight(); + + // The input shape is [N, C, H, W] + auto inputTy = input.getType().template cast(); + // The weight shape is [OC, (IC//G), KH, KW] + // If transposed is set to true, + // the weight shape changes to [IC, (OC//G), KH, KW] + auto weightTy = weight.getType().template cast(); + auto outTy = getTypeConverter() + ->convertType(op.getType()) + .template cast(); + if (!inputTy || !weightTy || !outTy) { + return op.emitError("input, weight and output must be ranked tensors"); + } + if (inputTy.getRank() < 3) + return op.emitError("only input with at least 3 dims valid"); + SmallVector stride; + if (!matchPattern(op.stride(), m_TorchConstantIntList(stride))) { + return rewriter.notifyMatchFailure(op, + "non-const stride list unsupported"); + } + SmallVector padding; + if (!matchPattern(op.padding(), m_TorchConstantIntList(padding))) { + return rewriter.notifyMatchFailure(op, + "non-const padding list unsupported"); + } + SmallVector dilation; + if (!matchPattern(op.dilation(), m_TorchConstantIntList(dilation))) { + return rewriter.notifyMatchFailure(op, + "non-const dilation list unsupported"); + } + SmallVector outputPadding; + if (!matchPattern(op.output_padding(), + m_TorchConstantIntList(outputPadding))) { + return rewriter.notifyMatchFailure( + op, "non-const output_padding list unsupported"); + } + int64_t groups; + if (!matchPattern(op.groups(), m_TorchConstantInt(&groups))) { + return rewriter.notifyMatchFailure(op, "non-int groups unsupported"); + } + bool transposed; + if (!matchPattern(op.transposed(), m_TorchConstantBool(&transposed))) { + return rewriter.notifyMatchFailure(op, "non-bool transposed unsupported"); + } + // Whether need to handle outputpadding + bool needHandleOutputPadding = false; + for (int64_t i : outputPadding) { + if (i != 0) { + needHandleOutputPadding = true; + break; + } + } + // Op validation check + if (needHandleOutputPadding && !transposed) { + return op->emitError( + "output padding attr is valid only in transposed convolution"); + } + assert(padding.size() == dilation.size() && + padding.size() == stride.size() && + padding.size() == static_cast(inputTy.getRank()) - 2 && + inputTy.getRank() == weightTy.getRank()); + + auto nSpatialDims = padding.size(); + auto nDims = inputTy.getRank(); + + // Kernel size must be constant. 
+ auto weightShape = weightTy.getShape(); + for (int i = 2; i < nDims; ++i) { + if (weightShape[i] == ShapedType::kDynamicSize) { + return rewriter.notifyMatchFailure( + op, "only constant kernel size is supported"); + } + } + + Value mhloConvResult; + if (transposed) { + mhloConvResult = convertTransposedConv( + op, rewriter, outTy, input, weight, stride, padding, dilation, + outputPadding, groups, needHandleOutputPadding); + } else { + mhloConvResult = convertNormalConv(op, rewriter, outTy, input, weight, + stride, padding, dilation, groups); + } auto bias = adaptor.bias(); // No bias provided if (failed(checkNotNone(rewriter, op, op.bias()))) { - rewriter.replaceOp(op, mhloConvOp.getResult()); + rewriter.replaceOp(op, mhloConvResult); return success(); } @@ -537,8 +718,8 @@ public: bias = mhlo::promoteType(rewriter, bias, outTy); DenseIntElementsAttr bcastDimensions; - rewriter.replaceOpWithNewOp( - op, outTy, mhloConvOp.getResult(), bias, bcastDimensions); + rewriter.replaceOpWithNewOp(op, outTy, mhloConvResult, + bias, bcastDimensions); return success(); } }; @@ -570,7 +751,7 @@ void mlir::torch::torch_to_mhlo::populateLinearOpPatternsAndLegality( #define INSERT_CONVOLUTION_ATENOP_PATTERN(AtenOp) \ target.addIllegalOp(); \ - patterns.add(typeConverter, context); + patterns.add(typeConverter, context); INSERT_CONVOLUTION_ATENOP_PATTERN(AtenConvolutionOp); #undef INSERT_CONVOLUTION_ATENOP_PATTERN } diff --git a/test/Conversion/TorchToMhlo/linear.mlir b/test/Conversion/TorchToMhlo/linear.mlir index d39be34e0..6c42f7af0 100644 --- a/test/Conversion/TorchToMhlo/linear.mlir +++ b/test/Conversion/TorchToMhlo/linear.mlir @@ -270,9 +270,9 @@ func.func @torch.aten.mm$proj(%arg0: !torch.vtensor<[?,256],f32>) -> !torch.vten // CHECK-LABEL: func.func @torch.aten.convolution( // CHECK-SAME: %[[ARG_0:.*]]: !torch.vtensor<[?,?,?,?],f32>, -// CHECK-SAME: %[[ARG_1:.*]]: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?,?],f32> { +// CHECK-SAME: %[[ARG_1:.*]]: !torch.vtensor<[?,?,3,3],f32>) -> !torch.vtensor<[?,?,?,?],f32> { // CHECK: %[[T_0:.*]] = torch_c.to_builtin_tensor %[[ARG_0]] : !torch.vtensor<[?,?,?,?],f32> -> tensor -// CHECK: %[[T_1:.*]] = torch_c.to_builtin_tensor %[[ARG_1]] : !torch.vtensor<[?,?,?,?],f32> -> tensor +// CHECK: %[[T_1:.*]] = torch_c.to_builtin_tensor %[[ARG_1]] : !torch.vtensor<[?,?,3,3],f32> -> tensor // CHECK: %[[T_2:.*]] = torch.constant.none // CHECK: %[[T_4:.*]] = torch.constant.int 2 // CHECK: %[[T_5:.*]] = torch.constant.int 1 @@ -285,10 +285,10 @@ func.func @torch.aten.mm$proj(%arg0: !torch.vtensor<[?,256],f32>) -> !torch.vten // CHECK: %[[T_12:.*]] = torch.prim.ListConstruct : () -> !torch.list // CHECK: %[[T_13:.*]] = torch.constant.bool false // CHECK: %[[T_14:.*]] = mhlo.convolution(%[[T_0]], %[[T_1]]) -// CHECK-SAME{LITERAL}: dim_numbers = [b, f, 0, 1]x[o, i, 0, 1]->[b, f, 0, 1], window = {stride = [2, 1], pad = [[4, 4], [2, 2]], rhs_dilate = [3, 1]} {batch_group_count = 1 : i64, feature_group_count = 3 : i64} : (tensor, tensor) -> tensor +// CHECK-SAME{LITERAL}: dim_numbers = [b, f, 0, 1]x[o, i, 0, 1]->[b, f, 0, 1], window = {stride = [2, 1], pad = [[4, 4], [2, 2]], rhs_dilate = [3, 1]} {batch_group_count = 1 : i64, feature_group_count = 3 : i64} : (tensor, tensor) -> tensor // CHECK: %[[T_15:.*]] = torch_c.from_builtin_tensor %[[T_14]] : tensor -> !torch.vtensor<[?,?,?,?],f32> // CHECK: return %[[T_15]] : !torch.vtensor<[?,?,?,?],f32> -func.func @torch.aten.convolution(%arg0: !torch.vtensor<[?,?,?,?],f32>, %arg1: !torch.vtensor<[?,?,?,?],f32>) -> 
!torch.vtensor<[?,?,?,?],f32> { +func.func @torch.aten.convolution(%arg0: !torch.vtensor<[?,?,?,?],f32>, %arg1: !torch.vtensor<[?,?,3,3],f32>) -> !torch.vtensor<[?,?,?,?],f32> { %none = torch.constant.none %int2 = torch.constant.int 2 %int1 = torch.constant.int 1 @@ -299,17 +299,17 @@ func.func @torch.aten.convolution(%arg0: !torch.vtensor<[?,?,?,?],f32>, %arg1: ! %3 = torch.prim.ListConstruct %int3, %int1 : (!torch.int, !torch.int) -> !torch.list %4 = torch.prim.ListConstruct : () -> !torch.list %false = torch.constant.bool false - %5 = torch.aten.convolution %arg0, %arg1, %none, %1, %2, %3, %false, %4, %int3 : !torch.vtensor<[?,?,?,?],f32>, !torch.vtensor<[?,?,?,?],f32>, !torch.none, !torch.list, !torch.list, !torch.list, !torch.bool, !torch.list, !torch.int -> !torch.vtensor<[?,?,?,?],f32> + %5 = torch.aten.convolution %arg0, %arg1, %none, %1, %2, %3, %false, %4, %int3 : !torch.vtensor<[?,?,?,?],f32>, !torch.vtensor<[?,?,3,3],f32>, !torch.none, !torch.list, !torch.list, !torch.list, !torch.bool, !torch.list, !torch.int -> !torch.vtensor<[?,?,?,?],f32> return %5 : !torch.vtensor<[?,?,?,?],f32> } // ----- // CHECK-LABEL: func.func @torch.aten.convolution$bias( -// CHECK-SAME: %[[ARG_0:.*]]: !torch.vtensor<[?,?,?,?],f32>, %[[ARG_1:.*]]: !torch.vtensor<[?,?,?,?],f32>, +// CHECK-SAME: %[[ARG_0:.*]]: !torch.vtensor<[?,?,?,?],f32>, %[[ARG_1:.*]]: !torch.vtensor<[?,?,3,3],f32>, // CHECK-SAME: %[[ARG_2:.*]]: !torch.vtensor<[?],f32>) -> !torch.vtensor<[?,?,?,?],f32> { // CHECK: %[[T_0:.*]] = torch_c.to_builtin_tensor %[[ARG_0]] : !torch.vtensor<[?,?,?,?],f32> -> tensor -// CHECK: %[[T_1:.*]] = torch_c.to_builtin_tensor %[[ARG_1]] : !torch.vtensor<[?,?,?,?],f32> -> tensor +// CHECK: %[[T_1:.*]] = torch_c.to_builtin_tensor %[[ARG_1]] : !torch.vtensor<[?,?,3,3],f32> -> tensor // CHECK: %[[T_2:.*]] = torch_c.to_builtin_tensor %[[ARG_2]] : !torch.vtensor<[?],f32> -> tensor // CHECK: %int2 = torch.constant.int 2 // CHECK: %int1 = torch.constant.int 1 @@ -322,7 +322,7 @@ func.func @torch.aten.convolution(%arg0: !torch.vtensor<[?,?,?,?],f32>, %arg1: ! // CHECK: %[[T_7:.*]] = torch.prim.ListConstruct : () -> !torch.list // CHECK: %false = torch.constant.bool false // CHECK: %[[T_8:.*]] = mhlo.convolution(%[[T_0]], %[[T_1]]) -// CHECK{LITERAL}: dim_numbers = [b, f, 0, 1]x[o, i, 0, 1]->[b, f, 0, 1], window = {stride = [2, 1], pad = [[4, 4], [2, 2]], rhs_dilate = [3, 1]} {batch_group_count = 1 : i64, feature_group_count = 3 : i64} : (tensor, tensor) -> tensor +// CHECK{LITERAL}: dim_numbers = [b, f, 0, 1]x[o, i, 0, 1]->[b, f, 0, 1], window = {stride = [2, 1], pad = [[4, 4], [2, 2]], rhs_dilate = [3, 1]} {batch_group_count = 1 : i64, feature_group_count = 3 : i64} : (tensor, tensor) -> tensor // CHECK: %[[IDX_0:.*]] = arith.constant 0 : index // CHECK: %[[T_9:.*]] = tensor.dim %[[T_2]], %[[IDX_0]] : tensor // CHECK: %[[T_10:.*]] = arith.index_cast %[[T_9]] : index to i64 @@ -332,7 +332,7 @@ func.func @torch.aten.convolution(%arg0: !torch.vtensor<[?,?,?,?],f32>, %arg1: ! 
// CHECK: %[[T_13:.*]] = chlo.broadcast_add %[[T_8]], %[[T_12]] : (tensor, tensor) -> tensor // CHECK: %[[T_14:.*]] = torch_c.from_builtin_tensor %[[T_13]] : tensor -> !torch.vtensor<[?,?,?,?],f32> // CHECK: return %[[T_14]] : !torch.vtensor<[?,?,?,?],f32> -func.func @torch.aten.convolution$bias(%arg0: !torch.vtensor<[?,?,?,?],f32>, %arg1: !torch.vtensor<[?,?,?,?],f32>, %arg2: !torch.vtensor<[?],f32>) -> !torch.vtensor<[?,?,?,?],f32> { +func.func @torch.aten.convolution$bias(%arg0: !torch.vtensor<[?,?,?,?],f32>, %arg1: !torch.vtensor<[?,?,3,3],f32>, %arg2: !torch.vtensor<[?],f32>) -> !torch.vtensor<[?,?,?,?],f32> { %int2 = torch.constant.int 2 %int1 = torch.constant.int 1 %int4 = torch.constant.int 4 @@ -342,6 +342,160 @@ func.func @torch.aten.convolution$bias(%arg0: !torch.vtensor<[?,?,?,?],f32>, %ar %3 = torch.prim.ListConstruct %int3, %int1 : (!torch.int, !torch.int) -> !torch.list %4 = torch.prim.ListConstruct : () -> !torch.list %false = torch.constant.bool false - %5 = torch.aten.convolution %arg0, %arg1, %arg2, %1, %2, %3, %false, %4, %int3 : !torch.vtensor<[?,?,?,?],f32>, !torch.vtensor<[?,?,?,?],f32>, !torch.vtensor<[?],f32>, !torch.list, !torch.list, !torch.list, !torch.bool, !torch.list, !torch.int -> !torch.vtensor<[?,?,?,?],f32> + %5 = torch.aten.convolution %arg0, %arg1, %arg2, %1, %2, %3, %false, %4, %int3 : !torch.vtensor<[?,?,?,?],f32>, !torch.vtensor<[?,?,3,3],f32>, !torch.vtensor<[?],f32>, !torch.list, !torch.list, !torch.list, !torch.bool, !torch.list, !torch.int -> !torch.vtensor<[?,?,?,?],f32> return %5 : !torch.vtensor<[?,?,?,?],f32> +} + +// ----- + +// CHECK-LABEL: func.func @torch.aten.convolution$transposed_basic( +// CHECK-SAME: %[[ARG_0:.*]]: !torch.vtensor<[1,2,7,7],f32>, +// CHECK-SAME: %[[ARG_1:.*]]: !torch.vtensor<[2,4,3,3],f32>) -> !torch.vtensor<[1,4,9,9],f32> { +// CHECK: %[[T_0:.*]] = torch_c.to_builtin_tensor %[[ARG_0]] : !torch.vtensor<[1,2,7,7],f32> -> tensor<1x2x7x7xf32> +// CHECK: %[[T_1:.*]] = torch_c.to_builtin_tensor %[[ARG_1]] : !torch.vtensor<[2,4,3,3],f32> -> tensor<2x4x3x3xf32> +// CHECK: %true = torch.constant.bool true +// CHECK: %none = torch.constant.none +// CHECK: %int0 = torch.constant.int 0 +// CHECK: %int1 = torch.constant.int 1 +// CHECK: %[[T_2:.*]] = torch_c.to_i64 %int1 +// CHECK: %[[T_3:.*]] = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list +// CHECK: %[[T_4:.*]] = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list +// CHECK: %[[T_5:.*]] = "mhlo.reverse"(%[[T_1]]) {dimensions = dense<[2, 3]> : tensor<2xi64>} : (tensor<2x4x3x3xf32>) -> tensor<2x4x3x3xf32> +// CHECK: %[[T_6:.*]] = mhlo.convolution(%[[T_0]], %[[T_5]]) +// CHECK{LITERAL}: dim_numbers = [b, f, 0, 1]x[i, o, 0, 1]->[b, f, 0, 1], window = {stride = [1, 1], pad = [[2, 2], [2, 2]], lhs_dilate = [1, 1], rhs_dilate = [1, 1]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor<1x2x7x7xf32>, tensor<2x4x3x3xf32>) -> tensor<1x4x9x9xf32> +// CHECK: %[[T_7:.*]] = torch_c.from_builtin_tensor %[[T_6]] : tensor<1x4x9x9xf32> -> !torch.vtensor<[1,4,9,9],f32> +// CHECK: return %[[T_7]] : !torch.vtensor<[1,4,9,9],f32> +func.func @torch.aten.convolution$transposed_basic(%arg0: !torch.vtensor<[1,2,7,7],f32>, %arg1: !torch.vtensor<[2,4,3,3],f32>) -> !torch.vtensor<[1,4,9,9],f32> { + %true = torch.constant.bool true + %none = torch.constant.none + %int0 = torch.constant.int 0 + %int1 = torch.constant.int 1 + %0 = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list + %1 = 
torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list + %2 = torch.aten.convolution %arg0, %arg1, %none, %1, %0, %1, %true, %0, %int1 : !torch.vtensor<[1,2,7,7],f32>, !torch.vtensor<[2,4,3,3],f32>, !torch.none, !torch.list, !torch.list, !torch.list, !torch.bool, !torch.list, !torch.int -> !torch.vtensor<[1,4,9,9],f32> + return %2 : !torch.vtensor<[1,4,9,9],f32> +} + +// ----- + +// CHECK-LABEL: func.func @torch.aten.convolution$transposed_stride( +// CHECK-SAME: %[[ARG_0:.*]]: !torch.vtensor<[1,2,7,7],f32>, +// CHECK-SAME: %[[ARG_1:.*]]: !torch.vtensor<[2,4,3,3],f32>) -> !torch.vtensor<[1,4,15,15],f32> { +// CHECK: %[[T_0:.*]] = torch_c.to_builtin_tensor %[[ARG_0]] : !torch.vtensor<[1,2,7,7],f32> -> tensor<1x2x7x7xf32> +// CHECK: %[[T_1:.*]] = torch_c.to_builtin_tensor %[[ARG_1]] : !torch.vtensor<[2,4,3,3],f32> -> tensor<2x4x3x3xf32> +// CHECK: %true = torch.constant.bool true +// CHECK: %none = torch.constant.none +// CHECK: %int0 = torch.constant.int 0 +// CHECK: %int1 = torch.constant.int 1 +// CHECK: %[[T_2:.*]] = torch_c.to_i64 %int1 +// CHECK: %int2 = torch.constant.int 2 +// CHECK: %[[T_3:.*]] = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list +// CHECK: %[[T_4:.*]] = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list +// CHECK: %[[T_5:.*]] = torch.prim.ListConstruct %int2, %int2 : (!torch.int, !torch.int) -> !torch.list +// CHECK: %[[T_6:.*]] = "mhlo.reverse"(%1) {dimensions = dense<[2, 3]> : tensor<2xi64>} : (tensor<2x4x3x3xf32>) -> tensor<2x4x3x3xf32> +// CHECK: %[[T_7:.*]] = mhlo.convolution(%[[T_0]], %[[T_6]]) +// CHECK{LITERAL}: dim_numbers = [b, f, 0, 1]x[i, o, 0, 1]->[b, f, 0, 1], window = {stride = [1, 1], pad = [[2, 2], [2, 2]], lhs_dilate = [2, 2], rhs_dilate = [1, 1]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor<1x2x7x7xf32>, tensor<2x4x3x3xf32>) -> tensor<1x4x15x15xf32> +// CHECK: %[[T_8:.*]] = torch_c.from_builtin_tensor %[[T_7]] : tensor<1x4x15x15xf32> -> !torch.vtensor<[1,4,15,15],f32> +// CHECK: return %[[T_8]] : !torch.vtensor<[1,4,15,15],f32> +func.func @torch.aten.convolution$transposed_stride(%arg0: !torch.vtensor<[1,2,7,7],f32>, %arg1: !torch.vtensor<[2,4,3,3],f32>) -> !torch.vtensor<[1,4,15,15],f32> { + %true = torch.constant.bool true + %none = torch.constant.none + %int0 = torch.constant.int 0 + %int1 = torch.constant.int 1 + %int2 = torch.constant.int 2 + %0 = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list + %1 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list + %2 = torch.prim.ListConstruct %int2, %int2 : (!torch.int, !torch.int) -> !torch.list + %3 = torch.aten.convolution %arg0, %arg1, %none, %2, %0, %1, %true, %0, %int1 : !torch.vtensor<[1,2,7,7],f32>, !torch.vtensor<[2,4,3,3],f32>, !torch.none, !torch.list, !torch.list, !torch.list, !torch.bool, !torch.list, !torch.int -> !torch.vtensor<[1,4,15,15],f32> + return %3 : !torch.vtensor<[1,4,15,15],f32> +} + +// ----- + +// CHECK-LABEL: func.func @torch.aten.convolution$transposed_outputpadding( +// CHECK-SAME: %[[ARG_0:.*]]: !torch.vtensor<[1,2,7,7],f32>, +// CHECK-SAME: %[[ARG_1:.*]]: !torch.vtensor<[2,4,3,3],f32>) -> !torch.vtensor<[1,4,16,16],f32> { +// CHECK: %[[T_0:.*]] = torch_c.to_builtin_tensor %[[ARG_0]] : !torch.vtensor<[1,2,7,7],f32> -> tensor<1x2x7x7xf32> +// CHECK: %[[T_1:.*]] = torch_c.to_builtin_tensor %[[ARG_1]] : !torch.vtensor<[2,4,3,3],f32> -> tensor<2x4x3x3xf32> +// CHECK: %true = torch.constant.bool true +// 
CHECK: %none = torch.constant.none +// CHECK: %int0 = torch.constant.int 0 +// CHECK: %int1 = torch.constant.int 1 +// CHECK: %[[T_2:.*]] = torch_c.to_i64 %int1 +// CHECK: %int2 = torch.constant.int 2 +// CHECK: %[[T_3:.*]] = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list +// CHECK: %[[T_4:.*]] = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list +// CHECK: %[[T_5:.*]] = torch.prim.ListConstruct %int2, %int2 : (!torch.int, !torch.int) -> !torch.list +// CHECK: %[[T_6:.*]] = "mhlo.reverse"(%[[T_1]]) {dimensions = dense<[2, 3]> : tensor<2xi64>} : (tensor<2x4x3x3xf32>) -> tensor<2x4x3x3xf32> +// CHECK: %[[T_7:.*]] = mhlo.convolution(%[[T_0]], %[[T_6]]) +// CHECK{LITERAL}: dim_numbers = [b, f, 0, 1]x[i, o, 0, 1]->[b, f, 0, 1], window = {stride = [1, 1], pad = [[2, 2], [2, 2]], lhs_dilate = [2, 2], rhs_dilate = [1, 1]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor<1x2x7x7xf32>, tensor<2x4x3x3xf32>) -> tensor<1x4x15x15xf32> +// CHECK: %[[T_8:.*]] = mhlo.constant dense<0.000000e+00> : tensor +// CHECK: %[[T_9:.*]] = "mhlo.pad"(%[[T_7]], %[[T_8]]) {edge_padding_high = dense<[0, 0, 1, 1]> : vector<4xi64>, edge_padding_low = dense<0> : vector<4xi64>, interior_padding = dense<0> : vector<4xi64>} : (tensor<1x4x15x15xf32>, tensor) -> tensor<1x4x16x16xf32> +// CHECK: %[[T_10:.*]] = torch_c.from_builtin_tensor %[[T_9:.*]] : tensor<1x4x16x16xf32> -> !torch.vtensor<[1,4,16,16],f32> +// CHECK: return %[[T_10]] : !torch.vtensor<[1,4,16,16],f32> +func.func @torch.aten.convolution$transposed_outputpadding(%arg0: !torch.vtensor<[1,2,7,7],f32>, %arg1: !torch.vtensor<[2,4,3,3],f32>) -> !torch.vtensor<[1,4,16,16],f32> { + %true = torch.constant.bool true + %none = torch.constant.none + %int0 = torch.constant.int 0 + %int1 = torch.constant.int 1 + %int2 = torch.constant.int 2 + %0 = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list + %1 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list + %2 = torch.prim.ListConstruct %int2, %int2 : (!torch.int, !torch.int) -> !torch.list + %3 = torch.aten.convolution %arg0, %arg1, %none, %2, %0, %1, %true, %1, %int1 : !torch.vtensor<[1,2,7,7],f32>, !torch.vtensor<[2,4,3,3],f32>, !torch.none, !torch.list, !torch.list, !torch.list, !torch.bool, !torch.list, !torch.int -> !torch.vtensor<[1,4,16,16],f32> + return %3 : !torch.vtensor<[1,4,16,16],f32> +} + +// ----- + +// CHECK-LABEL: func.func @torch.aten.convolution$transposed_groups( +// CHECK-SAME: %[[ARG_0:.*]]: !torch.vtensor<[1,2,7,7],f32>, +// CHECK-SAME: %[[ARG_1:.*]]: !torch.vtensor<[2,2,3,3],f32>) -> !torch.vtensor<[1,4,15,15],f32> { +// CHECK: %[[T_0:.*]] = torch_c.to_builtin_tensor %[[ARG_0]] : !torch.vtensor<[1,2,7,7],f32> -> tensor<1x2x7x7xf32> +// CHECK: %[[T_1:.*]] = torch_c.to_builtin_tensor %[[ARG_1]] : !torch.vtensor<[2,2,3,3],f32> -> tensor<2x2x3x3xf32> +// CHECK: %true = torch.constant.bool true +// CHECK: %none = torch.constant.none +// CHECK: %int0 = torch.constant.int 0 +// CHECK: %int1 = torch.constant.int 1 +// CHECK: %int2 = torch.constant.int 2 +// CHECK: %[[T_2:.*]] = torch_c.to_i64 %int2 +// CHECK: %[[T_3:.*]] = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list +// CHECK: %[[T_4:.*]] = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list +// CHECK: %[[T_5:.*]] = torch.prim.ListConstruct %int2, %int2 : (!torch.int, !torch.int) -> !torch.list +// CHECK: %[[T_6:.*]] = "mhlo.reverse"(%1) {dimensions = 
dense<[2, 3]> : tensor<2xi64>} : (tensor<2x2x3x3xf32>) -> tensor<2x2x3x3xf32> +// CHECK: %[[IDX_0:.*]] = arith.constant 0 : index +// CHECK: %[[T_7:.*]] = tensor.dim %[[T_6]], %[[IDX_0]] : tensor<2x2x3x3xf32> +// CHECK: %[[T_8:.*]] = arith.index_cast %[[T_7]] : index to i64 +// CHECK: %[[IDX_1:.*]] = arith.constant 1 : index +// CHECK: %[[T_9:.*]] = tensor.dim %[[T_6]], %[[IDX_1]] : tensor<2x2x3x3xf32> +// CHECK: %[[T_10:.*]] = arith.index_cast %[[T_9]] : index to i64 +// CHECK: %[[IDX_2:.*]] = arith.constant 2 : index +// CHECK: %[[T_11:.*]] = tensor.dim %[[T_6]], %[[IDX_2]] : tensor<2x2x3x3xf32> +// CHECK: %[[T_12:.*]] = arith.index_cast %[[T_11]] : index to i64 +// CHECK: %[[IDX_3:.*]] = arith.constant 3 : index +// CHECK: %[[T_13:.*]] = tensor.dim %[[T_6]], %[[IDX_3]] : tensor<2x2x3x3xf32> +// CHECK: %[[T_14:.*]] = arith.index_cast %[[T_13]] : index to i64 +// CHECK: %[[T_24:.*]] = arith.constant 2 : i64 +// CHECK: %[[T_15:.*]] = arith.divsi %[[T_8]], %[[T_24]] : i64 +// CHECK: %[[T_16:.*]] = arith.muli %[[T_10]], %[[T_24]] : i64 +// CHECK: %[[T_17:.*]] = tensor.from_elements %[[T_24]], %[[T_15]], %[[T_10]], %[[T_12]], %[[T_14]] : tensor<5xi64> +// CHECK: %[[T_18:.*]] = "mhlo.dynamic_reshape"(%[[T_6]], %[[T_17]]) : (tensor<2x2x3x3xf32>, tensor<5xi64>) -> tensor<2x1x2x3x3xf32> +// CHECK: %[[T_19:.*]] = "mhlo.transpose"(%[[T_18]]) {permutation = dense<[1, 0, 2, 3, 4]> : tensor<5xi64>} : (tensor<2x1x2x3x3xf32>) -> tensor<1x2x2x3x3xf32> +// CHECK: %[[T_20:.*]] = tensor.from_elements %[[T_15]], %[[T_16]], %[[T_12]], %[[T_14]] : tensor<4xi64> +// CHECK: %[[T_21:.*]] = "mhlo.dynamic_reshape"(%[[T_19]], %[[T_20]]) : (tensor<1x2x2x3x3xf32>, tensor<4xi64>) -> tensor<1x4x3x3xf32> +// CHECK: %[[T_22:.*]] = mhlo.convolution(%[[T_0]], %[[T_21]]) +// CHECK{LITERAL}: dim_numbers = [b, f, 0, 1]x[i, o, 0, 1]->[b, f, 0, 1], window = {stride = [1, 1], pad = [[2, 2], [2, 2]], lhs_dilate = [2, 2], rhs_dilate = [1, 1]} {batch_group_count = 1 : i64, feature_group_count = 2 : i64} : (tensor<1x2x7x7xf32>, tensor<1x4x3x3xf32>) -> tensor<1x4x15x15xf32> +// CHECK: %[[T_23:.*]] = torch_c.from_builtin_tensor %[[T_22]] : tensor<1x4x15x15xf32> -> !torch.vtensor<[1,4,15,15],f32> +// CHECK: return %[[T_23]] : !torch.vtensor<[1,4,15,15],f32> +func.func @torch.aten.convolution$transposed_groups(%arg0: !torch.vtensor<[1,2,7,7],f32>, %arg1: !torch.vtensor<[2,2,3,3],f32>) -> !torch.vtensor<[1,4,15,15],f32> { + %true = torch.constant.bool true + %none = torch.constant.none + %int0 = torch.constant.int 0 + %int1 = torch.constant.int 1 + %int2 = torch.constant.int 2 + %0 = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list + %1 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list + %2 = torch.prim.ListConstruct %int2, %int2 : (!torch.int, !torch.int) -> !torch.list + %3 = torch.aten.convolution %arg0, %arg1, %none, %2, %0, %1, %true, %0, %int2 : !torch.vtensor<[1,2,7,7],f32>, !torch.vtensor<[2,2,3,3],f32>, !torch.none, !torch.list, !torch.list, !torch.list, !torch.bool, !torch.list, !torch.int -> !torch.vtensor<[1,4,15,15],f32> + return %3 : !torch.vtensor<[1,4,15,15],f32> } \ No newline at end of file
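
Not part of the patch: a minimal standalone sketch (plain C++, hypothetical helper names) of the spatial-size arithmetic that convertTransposedConv above relies on. The lowering emits an mhlo.convolution with lhs_dilation = stride, symmetric padding of dilation*(kernel-1) - padding, unit window stride, and then applies output_padding as high-edge padding via mhlo.pad; the sketch checks that this reproduces the transposed-convolution output size already encoded in outTy, using the shapes from the lit tests above.

// transposed_conv_size_check.cpp -- assumption: illustrative only, not in the repo.
#include <cassert>
#include <cstdint>

// Spatial size produced by the mhlo.convolution emitted in convertTransposedConv.
int64_t mhloTransposedConvDim(int64_t in, int64_t k, int64_t stride,
                              int64_t padding, int64_t dilation,
                              int64_t outputPadding) {
  int64_t lhsDilated = (in - 1) * stride + 1;              // lhs_dilation = stride
  int64_t pad = dilation * (k - 1) - padding;              // applied low and high
  int64_t effectiveKernel = dilation * (k - 1) + 1;        // rhs_dilation = dilation
  int64_t conv = lhsDilated + 2 * pad - effectiveKernel + 1; // window stride = 1
  return conv + outputPadding;                             // mhlo.pad, edge_padding_high
}

// Spatial size PyTorch defines for a transposed convolution (what outTy encodes).
int64_t torchTransposedConvDim(int64_t in, int64_t k, int64_t stride,
                               int64_t padding, int64_t dilation,
                               int64_t outputPadding) {
  return (in - 1) * stride - 2 * padding + dilation * (k - 1) + outputPadding + 1;
}

int main() {
  // Shapes taken from the lit tests: 7x7 input, 3x3 kernel.
  assert(mhloTransposedConvDim(7, 3, /*stride=*/1, 0, 1, 0) == 9);   // $transposed_basic
  assert(mhloTransposedConvDim(7, 3, /*stride=*/2, 0, 1, 0) == 15);  // $transposed_stride
  assert(mhloTransposedConvDim(7, 3, /*stride=*/2, 0, 1, 1) == 16);  // $transposed_outputpadding

  // The two formulas agree whenever dilation*(k-1) >= padding (non-negative pre-padding).
  for (int64_t in : {5, 7, 9})
    for (int64_t stride : {1, 2, 3})
      for (int64_t padding : {0, 1})
        assert(mhloTransposedConvDim(in, 3, stride, padding, 1, 0) ==
               torchTransposedConvDim(in, 3, stride, padding, 1, 0));
  return 0;
}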