From c94431f71cd20760cfa4d0f0e520ef6c53738184 Mon Sep 17 00:00:00 2001
From: 武家伟 <73166454+Vremold@users.noreply.github.com>
Date: Thu, 4 Aug 2022 15:41:35 +0800
Subject: [PATCH] [MHLO] Add convolution op pattern (#1152)

Co-authored-by: Bairen Yi
Co-authored-by: Jiawei Wu
Co-authored-by: Tianyou Guo
Co-authored-by: Xu Yan
Co-authored-by: Ziheng Jiang
---
 lib/Conversion/TorchToMhlo/Linear.cpp   | 171 ++++++++++++++++++++++++
 test/Conversion/TorchToMhlo/linear.mlir |  79 +++++++++++
 2 files changed, 250 insertions(+)

diff --git a/lib/Conversion/TorchToMhlo/Linear.cpp b/lib/Conversion/TorchToMhlo/Linear.cpp
index d9475e8ef..09a3bc934 100644
--- a/lib/Conversion/TorchToMhlo/Linear.cpp
+++ b/lib/Conversion/TorchToMhlo/Linear.cpp
@@ -379,6 +379,171 @@ public:
 } // namespace
 
+// AtenConvolutionOp
+namespace {
+class ConvertAtenConvolutionOp
+    : public OpConversionPattern<AtenConvolutionOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+  using OpAdaptor = typename AtenConvolutionOp::Adaptor;
+
+  LogicalResult
+  matchAndRewrite(AtenConvolutionOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Value input = adaptor.input();
+    Value weight = adaptor.weight();
+
+    // The input shape is [N, C, H, W]
+    auto inputTy = input.getType().template cast<RankedTensorType>();
+    // The weight shape is [OC, (IC // groups), KH, KW].
+    // If transposed is set to true, the weight shape changes to
+    // [IC, (OC // groups), KH, KW].
+    auto weightTy = weight.getType().template cast<RankedTensorType>();
+    auto outTy = getTypeConverter()
+                     ->convertType(op.getType())
+                     .template cast<RankedTensorType>();
+
+    if (!inputTy || !weightTy || !outTy) {
+      return op.emitError("input, weight and output must be ranked tensors");
+    }
+
+    if (inputTy.getRank() < 3)
+      return op.emitError("only input with at least 3 dims valid");
+
+    SmallVector<int64_t> stride;
+    if (!matchPattern(op.stride(), m_TorchConstantIntList(stride))) {
+      return rewriter.notifyMatchFailure(op,
+                                         "non-const stride list unsupported");
+    }
+
+    SmallVector<int64_t> padding;
+    if (!matchPattern(op.padding(), m_TorchConstantIntList(padding))) {
+      return rewriter.notifyMatchFailure(op,
+                                         "non-const padding list unsupported");
+    }
+
+    SmallVector<int64_t> dilation;
+    if (!matchPattern(op.dilation(), m_TorchConstantIntList(dilation))) {
+      return rewriter.notifyMatchFailure(op,
+                                         "non-const dilation list unsupported");
+    }
+    SmallVector<int64_t> outputPadding;
+    if (!matchPattern(op.output_padding(),
+                      m_TorchConstantIntList(outputPadding))) {
+      return rewriter.notifyMatchFailure(
+          op, "non-const output_padding list unsupported");
+    }
+    // output_padding is only used by transposed convolution, which is
+    // rejected below, so only all-zero values are accepted here.
+    for (int64_t item : outputPadding) {
+      if (item != 0)
+        return rewriter.notifyMatchFailure(
+            op, "only zero output_padding list supported");
+    }
+
+    int64_t groups;
+    if (!matchPattern(op.groups(), m_TorchConstantInt(&groups))) {
+      return rewriter.notifyMatchFailure(op, "non-int groups unsupported");
+    }
+
+    bool transposed;
+    if (!matchPattern(op.transposed(), m_TorchConstantBool(&transposed))) {
+      return rewriter.notifyMatchFailure(op, "non-bool transposed unsupported");
+    }
+    if (transposed) {
+      return rewriter.notifyMatchFailure(
+          op, "only param transposed of value 'false' supported!");
+    }
+
+    assert(padding.size() == dilation.size() &&
+           padding.size() == stride.size() &&
+           padding.size() == static_cast<size_t>(inputTy.getRank()) - 2);
+    int64_t nSpatialDims = padding.size();
+
+    // Get mhlo::ConvolutionOp attributes
+    DenseIntElementsAttr mhloWindowStride = DenseIntElementsAttr::get(
+        RankedTensorType::get({static_cast<int64_t>(stride.size())},
+                              rewriter.getI64Type()),
+        stride);
+    std::vector<int64_t> mhloPaddingVec;
+    for (size_t i = 0; i < padding.size(); i++) {
+      mhloPaddingVec.emplace_back(padding[i]);
+      mhloPaddingVec.emplace_back(padding[i]);
+    }
+
+    DenseIntElementsAttr mhloPadding = DenseIntElementsAttr::get(
+        RankedTensorType::get(
+            {static_cast<int64_t>(padding.size()), static_cast<int64_t>(2)},
+            rewriter.getI64Type()),
+        mhloPaddingVec);
+
+    DenseIntElementsAttr mhloRhsDilation = DenseIntElementsAttr::get(
+        RankedTensorType::get({static_cast<int64_t>(dilation.size())},
+                              rewriter.getI64Type()),
+        dilation);
+
+    SmallVector<int64_t> spatialDimensions;
+    for (int64_t i = 2; i < inputTy.getRank(); i++) {
+      spatialDimensions.emplace_back(i);
+    }
+    mhlo::ConvDimensionNumbersAttr dimensionNumbers =
+        mhlo::ConvDimensionNumbersAttr::get(
+            /*context=*/rewriter.getContext(), /*inputBatchDimension=*/0,
+            /*inputFeatureDimension=*/1,
+            /*inputSpatialDimensions=*/spatialDimensions,
+            /*kernelInputFeatureDimension=*/1,
+            /*kernelOutputFeatureDimension=*/0,
+            /*kernelSpatialDimensions=*/spatialDimensions,
+            /*outputBatchDimension=*/0, /*outputFeatureDimension=*/1,
+            /*outputSpatialDimensions=*/spatialDimensions);
+
+    IntegerAttr featureGroupCount =
+        IntegerAttr::get(rewriter.getI64Type(), groups);
+    IntegerAttr batchGroupCount = IntegerAttr::get(rewriter.getI64Type(), 1);
+
+    // mhlo::ConvolutionOp's optional attributes, leave them as default
+    DenseIntElementsAttr mhloLhsDilation;
+    DenseElementsAttr windowReversal;
+    ArrayAttr precisionConfig;
+
+    auto mhloConvOp = rewriter.create<mhlo::ConvolutionOp>(
+        op->getLoc(), outTy, input, weight, mhloWindowStride, mhloPadding,
+        mhloLhsDilation, mhloRhsDilation, windowReversal, dimensionNumbers,
+        featureGroupCount, batchGroupCount, precisionConfig);
+
+    auto bias = adaptor.bias();
+
+    // No bias provided
+    if (failed(checkNotNone(rewriter, op, op.bias()))) {
+      rewriter.replaceOp(op, mhloConvOp.getResult());
+      return success();
+    }
+
+    // Handle bias
+    if (!bias.getType().isa<RankedTensorType>()) {
+      return op.emitError("bias provided but not a ranked tensor");
+    }
+
+    auto biasTy = bias.getType().template cast<RankedTensorType>();
+    if (!biasTy.getElementType().isIntOrFloat()) {
+      return op.emitError("only floating-point or integer datatype "
+                          "legalization for bias supported");
+    }
+
+    assert(biasTy.getRank() <= 1);
+
+    // Reshape and promote bias
+    auto inputUnsqzDims =
+        llvm::to_vector<4>(llvm::seq<int64_t>(-nSpatialDims, 0));
+    bias = *mhlo::unsqueezeTensor(rewriter, op, bias, inputUnsqzDims);
+    bias = mhlo::promoteType(rewriter, bias, outTy);
+
+    DenseIntElementsAttr bcastDimensions;
+    rewriter.replaceOpWithNewOp<chlo::BroadcastAddOp>(
+        op, outTy, mhloConvOp.getResult(), bias, bcastDimensions);
+    return success();
+  }
+};
+} // namespace
+
 void mlir::torch::torch_to_mhlo::populateLinearOpPatternsAndLegality(
     TypeConverter &typeConverter, RewritePatternSet &patterns,
     ConversionTarget &target) {
@@ -402,4 +567,10 @@ void mlir::torch::torch_to_mhlo::populateLinearOpPatternsAndLegality(
   patterns.add<ConvertAtenLinearOp<AtenOp>>(typeConverter, context);
   INSERT_LINEAR_ATENOP_PATTERN(AtenLinearOp);
 #undef INSERT_LINEAR_ATEMOP_PATTERN
+
+#define INSERT_CONVOLUTION_ATENOP_PATTERN(AtenOp)                             \
+  target.addIllegalOp<AtenOp>();                                              \
+  patterns.add<ConvertAtenConvolutionOp>(typeConverter, context);
+  INSERT_CONVOLUTION_ATENOP_PATTERN(AtenConvolutionOp);
+#undef INSERT_CONVOLUTION_ATENOP_PATTERN
 }
diff --git a/test/Conversion/TorchToMhlo/linear.mlir b/test/Conversion/TorchToMhlo/linear.mlir
index 18ea97654..d39be34e0 100644
--- a/test/Conversion/TorchToMhlo/linear.mlir
+++ b/test/Conversion/TorchToMhlo/linear.mlir
@@ -266,3 +266,82 @@ func.func @torch.aten.mm$proj(%arg0: !torch.vtensor<[?,256],f32>) -> !torch.vten
   return %1 : !torch.vtensor<[?,256],f32>
 }
+
+// -----
+
+// CHECK-LABEL: func.func @torch.aten.convolution(
+// CHECK-SAME: %[[ARG_0:.*]]: !torch.vtensor<[?,?,?,?],f32>,
+// CHECK-SAME: %[[ARG_1:.*]]: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?,?],f32> {
+// CHECK: %[[T_0:.*]] = torch_c.to_builtin_tensor %[[ARG_0]] : !torch.vtensor<[?,?,?,?],f32> -> tensor<?x?x?x?xf32>
+// CHECK: %[[T_1:.*]] = torch_c.to_builtin_tensor %[[ARG_1]] : !torch.vtensor<[?,?,?,?],f32> -> tensor<?x?x?x?xf32>
+// CHECK: %[[T_2:.*]] = torch.constant.none
+// CHECK: %[[T_4:.*]] = torch.constant.int 2
+// CHECK: %[[T_5:.*]] = torch.constant.int 1
+// CHECK: %[[T_6:.*]] = torch.constant.int 4
+// CHECK: %[[T_7:.*]] = torch.constant.int 3
+// CHECK: %[[T_8:.*]] = torch_c.to_i64 %[[T_7]]
+// CHECK: %[[T_9:.*]] = torch.prim.ListConstruct %[[T_4]], %[[T_5]] : (!torch.int, !torch.int) -> !torch.list<int>
+// CHECK: %[[T_10:.*]] = torch.prim.ListConstruct %[[T_6]], %[[T_4]] : (!torch.int, !torch.int) -> !torch.list<int>
+// CHECK: %[[T_11:.*]] = torch.prim.ListConstruct %[[T_7]], %[[T_5]] : (!torch.int, !torch.int) -> !torch.list<int>
+// CHECK: %[[T_12:.*]] = torch.prim.ListConstruct : () -> !torch.list<int>
+// CHECK: %[[T_13:.*]] = torch.constant.bool false
+// CHECK: %[[T_14:.*]] = mhlo.convolution(%[[T_0]], %[[T_1]])
+// CHECK-SAME{LITERAL}: dim_numbers = [b, f, 0, 1]x[o, i, 0, 1]->[b, f, 0, 1], window = {stride = [2, 1], pad = [[4, 4], [2, 2]], rhs_dilate = [3, 1]} {batch_group_count = 1 : i64, feature_group_count = 3 : i64} : (tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
+// CHECK: %[[T_15:.*]] = torch_c.from_builtin_tensor %[[T_14]] : tensor<?x?x?x?xf32> -> !torch.vtensor<[?,?,?,?],f32>
+// CHECK: return %[[T_15]] : !torch.vtensor<[?,?,?,?],f32>
+func.func @torch.aten.convolution(%arg0: !torch.vtensor<[?,?,?,?],f32>, %arg1: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?,?],f32> {
+  %none = torch.constant.none
+  %int2 = torch.constant.int 2
+  %int1 = torch.constant.int 1
+  %int4 = torch.constant.int 4
+  %int3 = torch.constant.int 3
+  %1 = torch.prim.ListConstruct %int2, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
+  %2 = torch.prim.ListConstruct %int4, %int2 : (!torch.int, !torch.int) -> !torch.list<int>
+  %3 = torch.prim.ListConstruct %int3, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
+  %4 = torch.prim.ListConstruct : () -> !torch.list<int>
+  %false = torch.constant.bool false
+  %5 = torch.aten.convolution %arg0, %arg1, %none, %1, %2, %3, %false, %4, %int3 : !torch.vtensor<[?,?,?,?],f32>, !torch.vtensor<[?,?,?,?],f32>, !torch.none, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int -> !torch.vtensor<[?,?,?,?],f32>
+  return %5 : !torch.vtensor<[?,?,?,?],f32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @torch.aten.convolution$bias(
+// CHECK-SAME: %[[ARG_0:.*]]: !torch.vtensor<[?,?,?,?],f32>, %[[ARG_1:.*]]: !torch.vtensor<[?,?,?,?],f32>,
+// CHECK-SAME: %[[ARG_2:.*]]: !torch.vtensor<[?],f32>) -> !torch.vtensor<[?,?,?,?],f32> {
+// CHECK: %[[T_0:.*]] = torch_c.to_builtin_tensor %[[ARG_0]] : !torch.vtensor<[?,?,?,?],f32> -> tensor<?x?x?x?xf32>
+// CHECK: %[[T_1:.*]] = torch_c.to_builtin_tensor %[[ARG_1]] : !torch.vtensor<[?,?,?,?],f32> -> tensor<?x?x?x?xf32>
+// CHECK: %[[T_2:.*]] = torch_c.to_builtin_tensor %[[ARG_2]] : !torch.vtensor<[?],f32> -> tensor<?xf32>
+// CHECK: %int2 = torch.constant.int 2
+// CHECK: %int1 = torch.constant.int 1
+// CHECK: %int4 = torch.constant.int 4
+// CHECK: %int3 = torch.constant.int 3
+// CHECK: %[[T_3:.*]] = torch_c.to_i64 %int3
+// CHECK: %[[T_4:.*]] = torch.prim.ListConstruct %int2, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
+// CHECK: %[[T_5:.*]] = torch.prim.ListConstruct %int4, %int2 : (!torch.int, !torch.int) -> !torch.list<int>
+// CHECK: %[[T_6:.*]] = torch.prim.ListConstruct %int3, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
+// CHECK: %[[T_7:.*]] = torch.prim.ListConstruct : () -> !torch.list<int>
+// CHECK: %false = torch.constant.bool false
+// CHECK: %[[T_8:.*]] = mhlo.convolution(%[[T_0]], %[[T_1]])
+// CHECK{LITERAL}: dim_numbers = [b, f, 0, 1]x[o, i, 0, 1]->[b, f, 0, 1], window = {stride = [2, 1], pad = [[4, 4], [2, 2]], rhs_dilate = [3, 1]} {batch_group_count = 1 : i64, feature_group_count = 3 : i64} : (tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
+// CHECK: %[[IDX_0:.*]] = arith.constant 0 : index
+// CHECK: %[[T_9:.*]] = tensor.dim %[[T_2]], %[[IDX_0]] : tensor<?xf32>
+// CHECK: %[[T_10:.*]] = arith.index_cast %[[T_9]] : index to i64
+// CHECK: %[[VAL_0:.*]] = arith.constant 1 : i64
+// CHECK: %[[T_11:.*]] = tensor.from_elements %[[T_10]], %[[VAL_0]], %[[VAL_0]] : tensor<3xi64>
+// CHECK: %[[T_12:.*]] = "mhlo.dynamic_reshape"(%[[T_2]], %[[T_11]]) : (tensor<?xf32>, tensor<3xi64>) -> tensor<?x1x1xf32>
+// CHECK: %[[T_13:.*]] = chlo.broadcast_add %[[T_8]], %[[T_12]] : (tensor<?x?x?x?xf32>, tensor<?x1x1xf32>) -> tensor<?x?x?x?xf32>
+// CHECK: %[[T_14:.*]] = torch_c.from_builtin_tensor %[[T_13]] : tensor<?x?x?x?xf32> -> !torch.vtensor<[?,?,?,?],f32>
+// CHECK: return %[[T_14]] : !torch.vtensor<[?,?,?,?],f32>
+func.func @torch.aten.convolution$bias(%arg0: !torch.vtensor<[?,?,?,?],f32>, %arg1: !torch.vtensor<[?,?,?,?],f32>, %arg2: !torch.vtensor<[?],f32>) -> !torch.vtensor<[?,?,?,?],f32> {
+  %int2 = torch.constant.int 2
+  %int1 = torch.constant.int 1
+  %int4 = torch.constant.int 4
+  %int3 = torch.constant.int 3
+  %1 = torch.prim.ListConstruct %int2, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
+  %2 = torch.prim.ListConstruct %int4, %int2 : (!torch.int, !torch.int) -> !torch.list<int>
+  %3 = torch.prim.ListConstruct %int3, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
+  %4 = torch.prim.ListConstruct : () -> !torch.list<int>
+  %false = torch.constant.bool false
+  %5 = torch.aten.convolution %arg0, %arg1, %arg2, %1, %2, %3, %false, %4, %int3 : !torch.vtensor<[?,?,?,?],f32>, !torch.vtensor<[?,?,?,?],f32>, !torch.vtensor<[?],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int -> !torch.vtensor<[?,?,?,?],f32>
+  return %5 : !torch.vtensor<[?,?,?,?],f32>
+}
\ No newline at end of file