mirror of https://github.com/llvm/torch-mlir
[MHLO] Add convolution op pattern (#1152)
Co-authored-by: Bairen Yi <yibairen.byron@bytedance.com>
Co-authored-by: Jiawei Wu <xremold@gmail.com>
Co-authored-by: Tianyou Guo <tianyou.gty@alibaba-inc.com>
Co-authored-by: Xu Yan <yancey.yx@alibaba-inc.com>
Co-authored-by: Ziheng Jiang <ziheng.jiang@bytedance.com>
parent 08fc2d89bb
commit c94431f71c
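The pattern lowers torch.aten.convolution to mhlo.convolution: stride, padding, and dilation become the window attributes (each symmetric Torch padding value is expanded into a (low, high) pair), groups maps to feature_group_count with batch_group_count fixed at 1, transposed convolutions and non-zero output_padding are rejected, and a non-None bias is reshaped to a broadcastable shape and added through chlo.broadcast_add. A condensed sketch of the mapping, taken from the first test case below (types on the torch op elided for brevity):

  // stride [2, 1], padding [4, 2], dilation [3, 1], groups = 3, no bias
  %5 = torch.aten.convolution %arg0, %arg1, %none, %1, %2, %3, %false, %4, %int3
  // lowers to:
  %0 = mhlo.convolution(%arg0, %arg1) dim_numbers = [b, f, 0, 1]x[o, i, 0, 1]->[b, f, 0, 1], window = {stride = [2, 1], pad = [[4, 4], [2, 2]], rhs_dilate = [3, 1]} {batch_group_count = 1 : i64, feature_group_count = 3 : i64} : (tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>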
@@ -379,6 +379,171 @@ public:
} // namespace

// AtenConvolutionOp
namespace {
class ConvertAtenConvolutionOp
    : public OpConversionPattern<AtenConvolutionOp> {
public:
  using OpConversionPattern<AtenConvolutionOp>::OpConversionPattern;
  using OpAdaptor = typename AtenConvolutionOp::Adaptor;

  LogicalResult
  matchAndRewrite(AtenConvolutionOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Value input = adaptor.input();
    Value weight = adaptor.weight();

    // The input shape is [N, C, H, W]
    auto inputTy = input.getType().template cast<RankedTensorType>();
    // The weight shape is [OC, (IC // groups), KH, KW].
    // If transposed is set to true, the weight shape changes to
    // [IC, (OC // groups), KH, KW].
    auto weightTy = weight.getType().template cast<RankedTensorType>();
    auto outTy = getTypeConverter()
                     ->convertType(op.getType())
                     .template cast<RankedTensorType>();

    if (!inputTy || !weightTy || !outTy) {
      return op.emitError("input, weight and output must be ranked tensors");
    }

    if (inputTy.getRank() < 3)
      return op.emitError("only input with at least 3 dims valid");

    SmallVector<int64_t> stride;
    if (!matchPattern(op.stride(), m_TorchConstantIntList(stride))) {
      return rewriter.notifyMatchFailure(op,
                                         "non-const stride list unsupported");
    }

    SmallVector<int64_t> padding;
    if (!matchPattern(op.padding(), m_TorchConstantIntList(padding))) {
      return rewriter.notifyMatchFailure(op,
                                         "non-const padding list unsupported");
    }

    SmallVector<int64_t> dilation;
    if (!matchPattern(op.dilation(), m_TorchConstantIntList(dilation))) {
      return rewriter.notifyMatchFailure(op,
                                         "non-const dilation list unsupported");
    }
    SmallVector<int64_t> outputPadding;
    if (!matchPattern(op.output_padding(),
                      m_TorchConstantIntList(outputPadding))) {
      return rewriter.notifyMatchFailure(
          op, "non-const output_padding list unsupported");
    }
    // Just ignore the outputPadding attribute
    for (int64_t item : outputPadding) {
      if (item != 0)
        return rewriter.notifyMatchFailure(
            op, "only zero output_padding list supported");
    }

    int64_t groups;
    if (!matchPattern(op.groups(), m_TorchConstantInt(&groups))) {
      return rewriter.notifyMatchFailure(op, "non-int groups unsupported");
    }

    bool transposed;
    if (!matchPattern(op.transposed(), m_TorchConstantBool(&transposed))) {
      return rewriter.notifyMatchFailure(op, "non-bool transposed unsupported");
    }
    if (transposed) {
      return rewriter.notifyMatchFailure(
          op, "only param transposed of value 'false' supported!");
    }

    assert(padding.size() == dilation.size() &&
           padding.size() == stride.size() &&
           padding.size() == static_cast<size_t>(inputTy.getRank()) - 2);
    int64_t nSpatialDims = padding.size();

    // Get mhlo::ConvolutionOp attributes
    DenseIntElementsAttr mhloWindowStride = DenseIntElementsAttr::get(
        RankedTensorType::get({static_cast<long int>(stride.size())},
                              rewriter.getI64Type()),
        stride);
    // mhlo::ConvolutionOp takes a (low, high) padding pair per spatial dim;
    // Torch padding is symmetric, so each value is repeated.
    std::vector<int64_t> mhloPaddingVec;
    for (size_t i = 0; i < padding.size(); i++) {
      mhloPaddingVec.emplace_back(padding[i]);
      mhloPaddingVec.emplace_back(padding[i]);
    }

    DenseIntElementsAttr mhloPadding = DenseIntElementsAttr::get(
        RankedTensorType::get(
            {static_cast<long int>(padding.size()), static_cast<long int>(2)},
            rewriter.getI64Type()),
        mhloPaddingVec);

    DenseIntElementsAttr mhloRhsDilation = DenseIntElementsAttr::get(
        RankedTensorType::get({static_cast<long int>(dilation.size())},
                              rewriter.getI64Type()),
        dilation);

    SmallVector<int64_t> spatialDimensions;
    for (int64_t i = 2; i < inputTy.getRank(); i++) {
      spatialDimensions.emplace_back(i);
    }
    mhlo::ConvDimensionNumbersAttr dimensionNumbers =
        mhlo::ConvDimensionNumbersAttr::get(
            /*context=*/rewriter.getContext(), /*inputBatchDimension=*/0,
            /*inputFeatureDimension=*/1,
            /*inputSpatialDimensions=*/spatialDimensions,
            /*kernelInputFeatureDimension=*/1,
            /*kernelOutputFeatureDimension=*/0,
            /*kernelSpatialDimensions=*/spatialDimensions,
            /*outputBatchDimension=*/0, /*outputFeatureDimension=*/1,
            /*outputSpatialDimensions=*/spatialDimensions);

    IntegerAttr featureGroupCount =
        IntegerAttr::get(rewriter.getI64Type(), groups);
    IntegerAttr batchGroupCount = IntegerAttr::get(rewriter.getI64Type(), 1);

    // mhlo::ConvolutionOp's optional attributes, leave them as default
    DenseIntElementsAttr mhloLhsDilation;
    DenseElementsAttr windowReversal;
    ArrayAttr precisionConfig;

    auto mhloConvOp = rewriter.create<mhlo::ConvolutionOp>(
        op->getLoc(), outTy, input, weight, mhloWindowStride, mhloPadding,
        mhloLhsDilation, mhloRhsDilation, windowReversal, dimensionNumbers,
        featureGroupCount, batchGroupCount, precisionConfig);

    auto bias = adaptor.bias();

    // No bias provided
    if (failed(checkNotNone(rewriter, op, op.bias()))) {
      rewriter.replaceOp(op, mhloConvOp.getResult());
      return success();
    }

    // Handle bias
    if (!bias.getType().dyn_cast<RankedTensorType>()) {
      return op.emitError("bias provided but not a ranked tensor");
    }

    auto biasTy = bias.getType().template cast<RankedTensorType>();
    if (!biasTy.getElementType().isIntOrFloat()) {
      return op.emitError("only floating-point or integer datatype "
                          "legalization for bias supported");
    }

    assert(biasTy.getRank() <= 1);

    // Reshape and promote bias
    auto inputUnsqzDims =
        llvm::to_vector<4>(llvm::seq<int64_t>(-nSpatialDims, 0));
    bias = *mhlo::unsqueezeTensor(rewriter, op, bias, inputUnsqzDims);
    bias = mhlo::promoteType(rewriter, bias, outTy);

    DenseIntElementsAttr bcastDimensions;
    rewriter.replaceOpWithNewOp<chlo::BroadcastAddOp>(
        op, outTy, mhloConvOp.getResult(), bias, bcastDimensions);
    return success();
  }
};
} // namespace

void mlir::torch::torch_to_mhlo::populateLinearOpPatternsAndLegality(
    TypeConverter &typeConverter, RewritePatternSet &patterns,
    ConversionTarget &target) {

@@ -402,4 +567,10 @@ void mlir::torch::torch_to_mhlo::populateLinearOpPatternsAndLegality(
  patterns.add<ConvertAtenLinearOp<AtenOp>>(typeConverter, context);
  INSERT_LINEAR_ATENOP_PATTERN(AtenLinearOp);
#undef INSERT_LINEAR_ATENOP_PATTERN

#define INSERT_CONVOLUTION_ATENOP_PATTERN(AtenOp)                              \
  target.addIllegalOp<AtenOp>();                                               \
  patterns.add<ConvertAtenConvolutionOp>(typeConverter, context);
  INSERT_CONVOLUTION_ATENOP_PATTERN(AtenConvolutionOp);
#undef INSERT_CONVOLUTION_ATENOP_PATTERN
}
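The lit test below exercises both the bias-free and the bias path end to end. The RUN line sits outside this hunk; assuming the file follows the usual TorchToMhlo lit-test layout, it would look roughly like:

  // RUN: torch-mlir-opt <%s -convert-torch-to-mhlo -split-input-file -verify-diagnostics | FileCheck %s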
@@ -266,3 +266,82 @@ func.func @torch.aten.mm$proj(%arg0: !torch.vtensor<[?,256],f32>) -> !torch.vten
  return %1 : !torch.vtensor<[?,256],f32>
}

// -----

// CHECK-LABEL: func.func @torch.aten.convolution(
// CHECK-SAME: %[[ARG_0:.*]]: !torch.vtensor<[?,?,?,?],f32>,
// CHECK-SAME: %[[ARG_1:.*]]: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?,?],f32> {
// CHECK: %[[T_0:.*]] = torch_c.to_builtin_tensor %[[ARG_0]] : !torch.vtensor<[?,?,?,?],f32> -> tensor<?x?x?x?xf32>
// CHECK: %[[T_1:.*]] = torch_c.to_builtin_tensor %[[ARG_1]] : !torch.vtensor<[?,?,?,?],f32> -> tensor<?x?x?x?xf32>
// CHECK: %[[T_2:.*]] = torch.constant.none
// CHECK: %[[T_4:.*]] = torch.constant.int 2
// CHECK: %[[T_5:.*]] = torch.constant.int 1
// CHECK: %[[T_6:.*]] = torch.constant.int 4
// CHECK: %[[T_7:.*]] = torch.constant.int 3
// CHECK: %[[T_8:.*]] = torch_c.to_i64 %[[T_7]]
// CHECK: %[[T_9:.*]] = torch.prim.ListConstruct %[[T_4]], %[[T_5]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[T_10:.*]] = torch.prim.ListConstruct %[[T_6]], %[[T_4]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[T_11:.*]] = torch.prim.ListConstruct %[[T_7]], %[[T_5]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[T_12:.*]] = torch.prim.ListConstruct : () -> !torch.list<int>
// CHECK: %[[T_13:.*]] = torch.constant.bool false
// CHECK: %[[T_14:.*]] = mhlo.convolution(%[[T_0]], %[[T_1]])
// CHECK-SAME{LITERAL}: dim_numbers = [b, f, 0, 1]x[o, i, 0, 1]->[b, f, 0, 1], window = {stride = [2, 1], pad = [[4, 4], [2, 2]], rhs_dilate = [3, 1]} {batch_group_count = 1 : i64, feature_group_count = 3 : i64} : (tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
// CHECK: %[[T_15:.*]] = torch_c.from_builtin_tensor %[[T_14]] : tensor<?x?x?x?xf32> -> !torch.vtensor<[?,?,?,?],f32>
// CHECK: return %[[T_15]] : !torch.vtensor<[?,?,?,?],f32>
func.func @torch.aten.convolution(%arg0: !torch.vtensor<[?,?,?,?],f32>, %arg1: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?,?],f32> {
  %none = torch.constant.none
  %int2 = torch.constant.int 2
  %int1 = torch.constant.int 1
  %int4 = torch.constant.int 4
  %int3 = torch.constant.int 3
  %1 = torch.prim.ListConstruct %int2, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
  %2 = torch.prim.ListConstruct %int4, %int2 : (!torch.int, !torch.int) -> !torch.list<int>
  %3 = torch.prim.ListConstruct %int3, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
  %4 = torch.prim.ListConstruct : () -> !torch.list<int>
  %false = torch.constant.bool false
  %5 = torch.aten.convolution %arg0, %arg1, %none, %1, %2, %3, %false, %4, %int3 : !torch.vtensor<[?,?,?,?],f32>, !torch.vtensor<[?,?,?,?],f32>, !torch.none, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int -> !torch.vtensor<[?,?,?,?],f32>
  return %5 : !torch.vtensor<[?,?,?,?],f32>
}

// -----

// CHECK-LABEL: func.func @torch.aten.convolution$bias(
// CHECK-SAME: %[[ARG_0:.*]]: !torch.vtensor<[?,?,?,?],f32>, %[[ARG_1:.*]]: !torch.vtensor<[?,?,?,?],f32>,
// CHECK-SAME: %[[ARG_2:.*]]: !torch.vtensor<[?],f32>) -> !torch.vtensor<[?,?,?,?],f32> {
// CHECK: %[[T_0:.*]] = torch_c.to_builtin_tensor %[[ARG_0]] : !torch.vtensor<[?,?,?,?],f32> -> tensor<?x?x?x?xf32>
// CHECK: %[[T_1:.*]] = torch_c.to_builtin_tensor %[[ARG_1]] : !torch.vtensor<[?,?,?,?],f32> -> tensor<?x?x?x?xf32>
// CHECK: %[[T_2:.*]] = torch_c.to_builtin_tensor %[[ARG_2]] : !torch.vtensor<[?],f32> -> tensor<?xf32>
// CHECK: %int2 = torch.constant.int 2
// CHECK: %int1 = torch.constant.int 1
// CHECK: %int4 = torch.constant.int 4
// CHECK: %int3 = torch.constant.int 3
// CHECK: %[[T_3:.*]] = torch_c.to_i64 %int3
// CHECK: %[[T_4:.*]] = torch.prim.ListConstruct %int2, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[T_5:.*]] = torch.prim.ListConstruct %int4, %int2 : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[T_6:.*]] = torch.prim.ListConstruct %int3, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[T_7:.*]] = torch.prim.ListConstruct : () -> !torch.list<int>
// CHECK: %false = torch.constant.bool false
// CHECK: %[[T_8:.*]] = mhlo.convolution(%[[T_0]], %[[T_1]])
// CHECK{LITERAL}: dim_numbers = [b, f, 0, 1]x[o, i, 0, 1]->[b, f, 0, 1], window = {stride = [2, 1], pad = [[4, 4], [2, 2]], rhs_dilate = [3, 1]} {batch_group_count = 1 : i64, feature_group_count = 3 : i64} : (tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
// CHECK: %[[IDX_0:.*]] = arith.constant 0 : index
// CHECK: %[[T_9:.*]] = tensor.dim %[[T_2]], %[[IDX_0]] : tensor<?xf32>
// CHECK: %[[T_10:.*]] = arith.index_cast %[[T_9]] : index to i64
// CHECK: %[[VAL_0:.*]] = arith.constant 1 : i64
// CHECK: %[[T_11:.*]] = tensor.from_elements %[[T_10]], %[[VAL_0]], %[[VAL_0]] : tensor<3xi64>
// CHECK: %[[T_12:.*]] = "mhlo.dynamic_reshape"(%[[T_2]], %[[T_11]]) : (tensor<?xf32>, tensor<3xi64>) -> tensor<?x1x1xf32>
// CHECK: %[[T_13:.*]] = chlo.broadcast_add %[[T_8]], %[[T_12]] : (tensor<?x?x?x?xf32>, tensor<?x1x1xf32>) -> tensor<?x?x?x?xf32>
// CHECK: %[[T_14:.*]] = torch_c.from_builtin_tensor %[[T_13]] : tensor<?x?x?x?xf32> -> !torch.vtensor<[?,?,?,?],f32>
// CHECK: return %[[T_14]] : !torch.vtensor<[?,?,?,?],f32>
func.func @torch.aten.convolution$bias(%arg0: !torch.vtensor<[?,?,?,?],f32>, %arg1: !torch.vtensor<[?,?,?,?],f32>, %arg2: !torch.vtensor<[?],f32>) -> !torch.vtensor<[?,?,?,?],f32> {
  %int2 = torch.constant.int 2
  %int1 = torch.constant.int 1
  %int4 = torch.constant.int 4
  %int3 = torch.constant.int 3
  %1 = torch.prim.ListConstruct %int2, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
  %2 = torch.prim.ListConstruct %int4, %int2 : (!torch.int, !torch.int) -> !torch.list<int>
  %3 = torch.prim.ListConstruct %int3, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
  %4 = torch.prim.ListConstruct : () -> !torch.list<int>
  %false = torch.constant.bool false
  %5 = torch.aten.convolution %arg0, %arg1, %arg2, %1, %2, %3, %false, %4, %int3 : !torch.vtensor<[?,?,?,?],f32>, !torch.vtensor<[?,?,?,?],f32>, !torch.vtensor<[?],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int -> !torch.vtensor<[?,?,?,?],f32>
  return %5 : !torch.vtensor<[?,?,?,?],f32>
}