[MHLO] Add transposed convolution conversion pattern (#1171)

Co-authored-by: Bairen Yi <yibairen.byron@bytedance.com>
Co-authored-by: Jiawei Wu <xremold@gmail.com>
Co-authored-by: Tianyou Guo <tianyou.gty@alibaba-inc.com>
Co-authored-by: Xu Yan <yancey.yx@alibaba-inc.com>
Co-authored-by: Ziheng Jiang <ziheng.jiang@bytedance.com>
pull/1187/head
武家伟 2022-08-09 09:50:07 +08:00 committed by GitHub
parent 504de5e701
commit 351f15424e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 420 additions and 85 deletions

View File

@ -379,84 +379,180 @@ public:
} // namespace
// AtenConvolutionOp
namespace {
class ConvertAtenConvlutionOp : public OpConversionPattern<AtenConvolutionOp> {
class ConvertAtenConvolutionOp : public OpConversionPattern<AtenConvolutionOp> {
public:
using OpConversionPattern<AtenConvolutionOp>::OpConversionPattern;
using OpAdaptor = typename AtenConvolutionOp::Adaptor;
LogicalResult
matchAndRewrite(AtenConvolutionOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Value input = adaptor.input();
Value weight = adaptor.weight();
Value reshapeConvWeight(PatternRewriter &rewriter, Operation *op,
Value weight, int64_t groups) const {
auto weightTy = weight.getType().cast<RankedTensorType>();
auto weightElemTy = weightTy.getElementType();
auto rank = weightTy.getRank();
SmallVector<Value> weightShapeVec =
*mhlo::getDimSizesOfTensor(rewriter, op, weight);
auto weightShape = weightTy.getShape();
SmallVector<int64_t> weightShapeInt(rank);
std::copy(weightShape.begin(), weightShape.end(), weightShapeInt.begin());
// The input shape is [N, C, H, W]
auto inputTy = input.getType().template cast<RankedTensorType>();
// The weight shape is [OC, (IC // groups), KH, KW]
// If tranposed is set to true, the weight shape changes to [IC, (OC //
// groups), KH, KW]
auto weightTy = weight.getType().template cast<RankedTensorType>();
auto outTy = getTypeConverter()
->convertType(op.getType())
.template cast<RankedTensorType>();
// 1. [IC, OC, H, W, ...] => [G, IC//G, OC, H, W, ...]
Value GValue = rewriter.create<mlir::arith::ConstantOp>(
op->getLoc(), rewriter.getI64IntegerAttr(groups));
Value ICDivGValue = rewriter.create<mlir::arith::DivSIOp>(
op->getLoc(), weightShapeVec[0], GValue);
Value OCMulGValue = rewriter.create<mlir::arith::MulIOp>(
op->getLoc(), weightShapeVec[1], GValue);
weightShapeVec[0] = ICDivGValue;
weightShapeVec.insert(weightShapeVec.begin(), GValue);
if (!inputTy || !weightTy || !outTy) {
return op.emitError("input, weight and output must be ranked tensors");
if (weightShapeInt[0] == ShapedType::kDynamicSize) {
weightShapeInt.insert(weightShapeInt.begin(), groups);
} else {
weightShapeInt[0] /= groups;
weightShapeInt.insert(weightShapeInt.begin(), groups);
}
Value weightShapeTensor = rewriter.create<mlir::tensor::FromElementsOp>(
op->getLoc(), weightShapeVec);
weight = rewriter.create<mhlo::DynamicReshapeOp>(
op->getLoc(), RankedTensorType::get(weightShapeInt, weightElemTy),
weight, weightShapeTensor);
// 2. [G, IC//G, OC, H, W, ...] => [IC//G, G, OC, H, W, ...]
std::vector<int64_t> transposeDims(rank + 1);
for (int64_t i = 0; i <= rank; i++)
transposeDims[i] = i;
std::swap(transposeDims[1], transposeDims[0]);
weight = rewriter.create<mhlo::TransposeOp>(
op->getLoc(), weight, rewriter.getI64TensorAttr(transposeDims));
// 3. [IC//G, G, OC, H, W, ...] => [IC//G, G*OC, H, W, ...]
weightShapeInt.erase(weightShapeInt.begin());
if (weightShapeInt[1] != ShapedType::kDynamicSize) {
weightShapeInt[1] *= groups;
}
weightShapeVec.erase(weightShapeVec.begin());
weightShapeVec[1] = OCMulGValue;
weightShapeTensor = rewriter.create<mlir::tensor::FromElementsOp>(
op->getLoc(), weightShapeVec);
weight = rewriter.create<mhlo::DynamicReshapeOp>(
op->getLoc(), RankedTensorType::get(weightShapeInt, weightElemTy),
weight, weightShapeTensor);
return weight;
}
Value convertTransposedConv(AtenConvolutionOp op,
ConversionPatternRewriter &rewriter,
RankedTensorType outType, Value input,
Value weight, ArrayRef<int64_t> stride,
ArrayRef<int64_t> padding,
ArrayRef<int64_t> dilation,
ArrayRef<int64_t> outputPadding, int64_t groups,
bool needHandleOutputPadding) const {
auto inputTy = input.getType().cast<RankedTensorType>();
auto weightTy = weight.getType().cast<RankedTensorType>();
auto weightShape = weightTy.getShape();
auto nDims = inputTy.getRank();
auto nSpatialDims = nDims - 2;
auto convOutTy = outType;
if (needHandleOutputPadding) {
SmallVector<int64_t> outShape(nDims);
auto finalOutShape = outType.getShape();
std::copy(finalOutShape.begin(), finalOutShape.end(), outShape.begin());
for (int i = 2; i < nDims; ++i) {
if (finalOutShape[i] == ShapedType::kDynamicSize)
continue;
outShape[i] = finalOutShape[i] - outputPadding[i - 2];
}
convOutTy = RankedTensorType::get(outShape, outType.getElementType());
}
if (inputTy.getRank() < 3)
return op.emitError("only input with at least 3 dims valid");
// Prepare for transposed convolution
SmallVector<int64_t> mhloStrideVec(nSpatialDims, 1);
DenseIntElementsAttr mhloStride = rewriter.getI64TensorAttr(mhloStrideVec);
SmallVector<int64_t> mhloPaddingVec(nSpatialDims * 2, 0);
for (int i = 0; i < nSpatialDims; ++i) {
int64_t padInt = dilation[i] * (weightShape[i + 2] - 1) - padding[i];
mhloPaddingVec[i * 2] = padInt;
mhloPaddingVec[i * 2 + 1] = padInt;
}
DenseIntElementsAttr mhloPadding = DenseIntElementsAttr::get(
RankedTensorType::get({nSpatialDims, 2}, rewriter.getI64Type()),
mhloPaddingVec);
SmallVector<int64_t> mhloLhsDilationVec(nSpatialDims);
std::copy(stride.begin(), stride.end(), mhloLhsDilationVec.begin());
DenseIntElementsAttr mhloLhsDilation =
rewriter.getI64TensorAttr(mhloLhsDilationVec);
SmallVector<int64_t> mhloRhsDilationVec(nSpatialDims);
std::copy(dilation.begin(), dilation.end(), mhloRhsDilationVec.begin());
DenseIntElementsAttr mhloRhsDilation =
rewriter.getI64TensorAttr(mhloRhsDilationVec);
SmallVector<int64_t> stride;
if (!matchPattern(op.stride(), m_TorchConstantIntList(stride))) {
return rewriter.notifyMatchFailure(op,
"non-const stride list unsupported");
DenseElementsAttr windowReversal;
ArrayAttr precisionConfig;
SmallVector<int64_t> spatialDims;
for (int i = 0; i < nSpatialDims; ++i) {
spatialDims.push_back(i + 2);
}
mhlo::ConvDimensionNumbersAttr dimensionNumbers =
mhlo::ConvDimensionNumbersAttr::get(
/*context=*/rewriter.getContext(), /*inputBatchDimension=*/0,
/*inputFeatureDimension=*/1,
/*inputSpatialDimensions=*/spatialDims,
/*kernelInputFeatureDimension=*/0,
/*kernelOutputFeatureDimension=*/1,
/*kernelSpatialDimensions=*/spatialDims,
/*outputBatchDimension=*/0, /*outputFeatureDimension=*/1,
/*outputSpatialDimensions=*/spatialDims);
// Reverse and transpose weight
weight = rewriter.create<mhlo::ReverseOp>(
op->getLoc(), weight, rewriter.getI64TensorAttr(spatialDims));
if (groups != 1) {
weight = reshapeConvWeight(rewriter, op, weight, groups);
}
SmallVector<int64_t> padding;
if (!matchPattern(op.padding(), m_TorchConstantIntList(padding))) {
return rewriter.notifyMatchFailure(op,
"non-const padding list unsupported");
}
// Create transposed convolution
auto transposedConvOp = rewriter.create<mhlo::ConvolutionOp>(
op->getLoc(), convOutTy, input, weight, mhloStride, mhloPadding,
mhloLhsDilation, mhloRhsDilation, windowReversal, dimensionNumbers,
static_cast<uint64_t>(groups), 1, precisionConfig);
SmallVector<int64_t> dilation;
if (!matchPattern(op.dilation(), m_TorchConstantIntList(dilation))) {
return rewriter.notifyMatchFailure(op,
"non-const dilation list unsupported");
}
SmallVector<int64_t> outputPadding;
if (!matchPattern(op.output_padding(),
m_TorchConstantIntList(outputPadding))) {
return rewriter.notifyMatchFailure(
op, "non-const output_padding list unsupported");
}
// Just ignore the outputPadding attribute
for (int64_t item : outputPadding) {
if (item != 0)
return rewriter.notifyMatchFailure(
op, "only zero output_padding list supported");
// Handle output padding
if (!needHandleOutputPadding) {
return transposedConvOp.getResult();
}
SmallVector<int64_t> edgePaddingLowVec(nDims, 0);
SmallVector<int64_t> edgePaddingHighVec(nDims, 0);
SmallVector<int64_t> interiorPaddingVec(nDims, 0);
std::copy(outputPadding.begin(), outputPadding.end(),
edgePaddingHighVec.begin() + 2);
Value paddingValue =
mhlo::getConstTensor<float>(rewriter, op, {0.0}, {}).getValue();
paddingValue = mhlo::promoteType(rewriter, paddingValue, inputTy);
mlir::DenseIntElementsAttr edgePaddingLow =
rewriter.getI64VectorAttr(edgePaddingLowVec);
mlir::DenseIntElementsAttr edgePaddingHigh =
rewriter.getI64VectorAttr(edgePaddingHighVec);
mlir::DenseIntElementsAttr interiorPadding =
rewriter.getI64VectorAttr(interiorPaddingVec);
int64_t groups;
if (!matchPattern(op.groups(), m_TorchConstantInt(&groups))) {
return rewriter.notifyMatchFailure(op, "non-int groups unsupported");
}
auto paddedOutput = rewriter.create<mhlo::PadOp>(
op->getLoc(), outType, transposedConvOp, paddingValue, edgePaddingLow,
edgePaddingHigh, interiorPadding);
bool transposed;
if (!matchPattern(op.transposed(), m_TorchConstantBool(&transposed))) {
return rewriter.notifyMatchFailure(op, "non-bool transposed unsupported");
}
if (transposed) {
return rewriter.notifyMatchFailure(
op, "only param tranposed of value 'false' supported!");
}
return paddedOutput.getResult();
}
assert(padding.size() == dilation.size() &&
padding.size() == stride.size() &&
padding.size() == static_cast<size_t>(inputTy.getRank()) - 2);
int64_t nSpatialDims = padding.size();
Value convertNormalConv(AtenConvolutionOp op,
ConversionPatternRewriter &rewriter,
RankedTensorType outType, Value input, Value weight,
ArrayRef<int64_t> stride, ArrayRef<int64_t> padding,
ArrayRef<int64_t> dilation, int64_t groups) const {
int64_t nDims = outType.getRank();
// Get mhlo::ConvolutionOp attributes
DenseIntElementsAttr mhloWindowStride = DenseIntElementsAttr::get(
@ -468,20 +564,17 @@ public:
mhloPaddingVec.emplace_back(padding[i]);
mhloPaddingVec.emplace_back(padding[i]);
}
DenseIntElementsAttr mhloPadding = DenseIntElementsAttr::get(
RankedTensorType::get(
{static_cast<long int>(padding.size()), static_cast<long int>(2)},
rewriter.getI64Type()),
mhloPaddingVec);
DenseIntElementsAttr mhloRhsDilation = DenseIntElementsAttr::get(
RankedTensorType::get({static_cast<long int>(dilation.size())},
rewriter.getI64Type()),
dilation);
SmallVector<int64_t> spatialDimensions;
for (int64_t i = 2; i < inputTy.getRank(); i++) {
for (int64_t i = 2; i < nDims; i++) {
spatialDimensions.emplace_back(i);
}
mhlo::ConvDimensionNumbersAttr dimensionNumbers =
@ -495,25 +588,113 @@ public:
/*outputBatchDimension=*/0, /*outputFeatureDimension=*/1,
/*outputSpatialDimensions=*/spatialDimensions);
IntegerAttr featureGroupCount =
IntegerAttr::get(rewriter.getI64Type(), groups);
IntegerAttr batchGroupCount = IntegerAttr::get(rewriter.getI64Type(), 1);
// mhlo::ConvolutionOp's optional attributes, leave them as default
DenseIntElementsAttr mhloLhsDilation;
DenseElementsAttr windowReversal;
ArrayAttr precisionConfig;
auto mhloConvOp = rewriter.create<mhlo::ConvolutionOp>(
op->getLoc(), outTy, input, weight, mhloWindowStride, mhloPadding,
op->getLoc(), outType, input, weight, mhloWindowStride, mhloPadding,
mhloLhsDilation, mhloRhsDilation, windowReversal, dimensionNumbers,
featureGroupCount, batchGroupCount, precisionConfig);
static_cast<uint64_t>(groups), 1, precisionConfig);
return mhloConvOp.getResult();
}
LogicalResult
matchAndRewrite(AtenConvolutionOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Value input = adaptor.input();
Value weight = adaptor.weight();
// The input shape is [N, C, H, W]
auto inputTy = input.getType().template cast<RankedTensorType>();
// The weight shape is [OC, (IC//G), KH, KW]
// If transposed is set to true,
// the weight shape changes to [IC, (OC//G), KH, KW]
auto weightTy = weight.getType().template cast<RankedTensorType>();
auto outTy = getTypeConverter()
->convertType(op.getType())
.template cast<RankedTensorType>();
if (!inputTy || !weightTy || !outTy) {
return op.emitError("input, weight and output must be ranked tensors");
}
if (inputTy.getRank() < 3)
return op.emitError("only input with at least 3 dims valid");
SmallVector<int64_t> stride;
if (!matchPattern(op.stride(), m_TorchConstantIntList(stride))) {
return rewriter.notifyMatchFailure(op,
"non-const stride list unsupported");
}
SmallVector<int64_t> padding;
if (!matchPattern(op.padding(), m_TorchConstantIntList(padding))) {
return rewriter.notifyMatchFailure(op,
"non-const padding list unsupported");
}
SmallVector<int64_t> dilation;
if (!matchPattern(op.dilation(), m_TorchConstantIntList(dilation))) {
return rewriter.notifyMatchFailure(op,
"non-const dilation list unsupported");
}
SmallVector<int64_t> outputPadding;
if (!matchPattern(op.output_padding(),
m_TorchConstantIntList(outputPadding))) {
return rewriter.notifyMatchFailure(
op, "non-const output_padding list unsupported");
}
int64_t groups;
if (!matchPattern(op.groups(), m_TorchConstantInt(&groups))) {
return rewriter.notifyMatchFailure(op, "non-int groups unsupported");
}
bool transposed;
if (!matchPattern(op.transposed(), m_TorchConstantBool(&transposed))) {
return rewriter.notifyMatchFailure(op, "non-bool transposed unsupported");
}
// Whether need to handle outputpadding
bool needHandleOutputPadding = false;
for (int64_t i : outputPadding) {
if (i != 0) {
needHandleOutputPadding = true;
break;
}
}
// Op validation check
if (needHandleOutputPadding && !transposed) {
return op->emitError(
"output padding attr is valid only in transposed convolution");
}
assert(padding.size() == dilation.size() &&
padding.size() == stride.size() &&
padding.size() == static_cast<size_t>(inputTy.getRank()) - 2 &&
inputTy.getRank() == weightTy.getRank());
auto nSpatialDims = padding.size();
auto nDims = inputTy.getRank();
// Kernel size must be constant.
auto weightShape = weightTy.getShape();
for (int i = 2; i < nDims; ++i) {
if (weightShape[i] == ShapedType::kDynamicSize) {
return rewriter.notifyMatchFailure(
op, "only constant kernel size is supported");
}
}
Value mhloConvResult;
if (transposed) {
mhloConvResult = convertTransposedConv(
op, rewriter, outTy, input, weight, stride, padding, dilation,
outputPadding, groups, needHandleOutputPadding);
} else {
mhloConvResult = convertNormalConv(op, rewriter, outTy, input, weight,
stride, padding, dilation, groups);
}
auto bias = adaptor.bias();
// No bias provided
if (failed(checkNotNone(rewriter, op, op.bias()))) {
rewriter.replaceOp(op, mhloConvOp.getResult());
rewriter.replaceOp(op, mhloConvResult);
return success();
}
@ -537,8 +718,8 @@ public:
bias = mhlo::promoteType(rewriter, bias, outTy);
DenseIntElementsAttr bcastDimensions;
rewriter.replaceOpWithNewOp<chlo::BroadcastAddOp>(
op, outTy, mhloConvOp.getResult(), bias, bcastDimensions);
rewriter.replaceOpWithNewOp<chlo::BroadcastAddOp>(op, outTy, mhloConvResult,
bias, bcastDimensions);
return success();
}
};
@ -570,7 +751,7 @@ void mlir::torch::torch_to_mhlo::populateLinearOpPatternsAndLegality(
#define INSERT_CONVOLUTION_ATENOP_PATTERN(AtenOp) \
target.addIllegalOp<AtenOp>(); \
patterns.add<ConvertAtenConvlutionOp>(typeConverter, context);
patterns.add<ConvertAtenConvolutionOp>(typeConverter, context);
INSERT_CONVOLUTION_ATENOP_PATTERN(AtenConvolutionOp);
#undef INSERT_CONVOLUTION_ATENOP_PATTERN
}

View File

@ -270,9 +270,9 @@ func.func @torch.aten.mm$proj(%arg0: !torch.vtensor<[?,256],f32>) -> !torch.vten
// CHECK-LABEL: func.func @torch.aten.convolution(
// CHECK-SAME: %[[ARG_0:.*]]: !torch.vtensor<[?,?,?,?],f32>,
// CHECK-SAME: %[[ARG_1:.*]]: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?,?],f32> {
// CHECK-SAME: %[[ARG_1:.*]]: !torch.vtensor<[?,?,3,3],f32>) -> !torch.vtensor<[?,?,?,?],f32> {
// CHECK: %[[T_0:.*]] = torch_c.to_builtin_tensor %[[ARG_0]] : !torch.vtensor<[?,?,?,?],f32> -> tensor<?x?x?x?xf32>
// CHECK: %[[T_1:.*]] = torch_c.to_builtin_tensor %[[ARG_1]] : !torch.vtensor<[?,?,?,?],f32> -> tensor<?x?x?x?xf32>
// CHECK: %[[T_1:.*]] = torch_c.to_builtin_tensor %[[ARG_1]] : !torch.vtensor<[?,?,3,3],f32> -> tensor<?x?x3x3xf32>
// CHECK: %[[T_2:.*]] = torch.constant.none
// CHECK: %[[T_4:.*]] = torch.constant.int 2
// CHECK: %[[T_5:.*]] = torch.constant.int 1
@ -285,10 +285,10 @@ func.func @torch.aten.mm$proj(%arg0: !torch.vtensor<[?,256],f32>) -> !torch.vten
// CHECK: %[[T_12:.*]] = torch.prim.ListConstruct : () -> !torch.list<int>
// CHECK: %[[T_13:.*]] = torch.constant.bool false
// CHECK: %[[T_14:.*]] = mhlo.convolution(%[[T_0]], %[[T_1]])
// CHECK-SAME{LITERAL}: dim_numbers = [b, f, 0, 1]x[o, i, 0, 1]->[b, f, 0, 1], window = {stride = [2, 1], pad = [[4, 4], [2, 2]], rhs_dilate = [3, 1]} {batch_group_count = 1 : i64, feature_group_count = 3 : i64} : (tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
// CHECK-SAME{LITERAL}: dim_numbers = [b, f, 0, 1]x[o, i, 0, 1]->[b, f, 0, 1], window = {stride = [2, 1], pad = [[4, 4], [2, 2]], rhs_dilate = [3, 1]} {batch_group_count = 1 : i64, feature_group_count = 3 : i64} : (tensor<?x?x?x?xf32>, tensor<?x?x3x3xf32>) -> tensor<?x?x?x?xf32>
// CHECK: %[[T_15:.*]] = torch_c.from_builtin_tensor %[[T_14]] : tensor<?x?x?x?xf32> -> !torch.vtensor<[?,?,?,?],f32>
// CHECK: return %[[T_15]] : !torch.vtensor<[?,?,?,?],f32>
func.func @torch.aten.convolution(%arg0: !torch.vtensor<[?,?,?,?],f32>, %arg1: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?,?],f32> {
func.func @torch.aten.convolution(%arg0: !torch.vtensor<[?,?,?,?],f32>, %arg1: !torch.vtensor<[?,?,3,3],f32>) -> !torch.vtensor<[?,?,?,?],f32> {
%none = torch.constant.none
%int2 = torch.constant.int 2
%int1 = torch.constant.int 1
@ -299,17 +299,17 @@ func.func @torch.aten.convolution(%arg0: !torch.vtensor<[?,?,?,?],f32>, %arg1: !
%3 = torch.prim.ListConstruct %int3, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
%4 = torch.prim.ListConstruct : () -> !torch.list<int>
%false = torch.constant.bool false
%5 = torch.aten.convolution %arg0, %arg1, %none, %1, %2, %3, %false, %4, %int3 : !torch.vtensor<[?,?,?,?],f32>, !torch.vtensor<[?,?,?,?],f32>, !torch.none, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int -> !torch.vtensor<[?,?,?,?],f32>
%5 = torch.aten.convolution %arg0, %arg1, %none, %1, %2, %3, %false, %4, %int3 : !torch.vtensor<[?,?,?,?],f32>, !torch.vtensor<[?,?,3,3],f32>, !torch.none, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int -> !torch.vtensor<[?,?,?,?],f32>
return %5 : !torch.vtensor<[?,?,?,?],f32>
}
// -----
// CHECK-LABEL: func.func @torch.aten.convolution$bias(
// CHECK-SAME: %[[ARG_0:.*]]: !torch.vtensor<[?,?,?,?],f32>, %[[ARG_1:.*]]: !torch.vtensor<[?,?,?,?],f32>,
// CHECK-SAME: %[[ARG_0:.*]]: !torch.vtensor<[?,?,?,?],f32>, %[[ARG_1:.*]]: !torch.vtensor<[?,?,3,3],f32>,
// CHECK-SAME: %[[ARG_2:.*]]: !torch.vtensor<[?],f32>) -> !torch.vtensor<[?,?,?,?],f32> {
// CHECK: %[[T_0:.*]] = torch_c.to_builtin_tensor %[[ARG_0]] : !torch.vtensor<[?,?,?,?],f32> -> tensor<?x?x?x?xf32>
// CHECK: %[[T_1:.*]] = torch_c.to_builtin_tensor %[[ARG_1]] : !torch.vtensor<[?,?,?,?],f32> -> tensor<?x?x?x?xf32>
// CHECK: %[[T_1:.*]] = torch_c.to_builtin_tensor %[[ARG_1]] : !torch.vtensor<[?,?,3,3],f32> -> tensor<?x?x3x3xf32>
// CHECK: %[[T_2:.*]] = torch_c.to_builtin_tensor %[[ARG_2]] : !torch.vtensor<[?],f32> -> tensor<?xf32>
// CHECK: %int2 = torch.constant.int 2
// CHECK: %int1 = torch.constant.int 1
@ -322,7 +322,7 @@ func.func @torch.aten.convolution(%arg0: !torch.vtensor<[?,?,?,?],f32>, %arg1: !
// CHECK: %[[T_7:.*]] = torch.prim.ListConstruct : () -> !torch.list<int>
// CHECK: %false = torch.constant.bool false
// CHECK: %[[T_8:.*]] = mhlo.convolution(%[[T_0]], %[[T_1]])
// CHECK{LITERAL}: dim_numbers = [b, f, 0, 1]x[o, i, 0, 1]->[b, f, 0, 1], window = {stride = [2, 1], pad = [[4, 4], [2, 2]], rhs_dilate = [3, 1]} {batch_group_count = 1 : i64, feature_group_count = 3 : i64} : (tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
// CHECK{LITERAL}: dim_numbers = [b, f, 0, 1]x[o, i, 0, 1]->[b, f, 0, 1], window = {stride = [2, 1], pad = [[4, 4], [2, 2]], rhs_dilate = [3, 1]} {batch_group_count = 1 : i64, feature_group_count = 3 : i64} : (tensor<?x?x?x?xf32>, tensor<?x?x3x3xf32>) -> tensor<?x?x?x?xf32>
// CHECK: %[[IDX_0:.*]] = arith.constant 0 : index
// CHECK: %[[T_9:.*]] = tensor.dim %[[T_2]], %[[IDX_0]] : tensor<?xf32>
// CHECK: %[[T_10:.*]] = arith.index_cast %[[T_9]] : index to i64
@ -332,7 +332,7 @@ func.func @torch.aten.convolution(%arg0: !torch.vtensor<[?,?,?,?],f32>, %arg1: !
// CHECK: %[[T_13:.*]] = chlo.broadcast_add %[[T_8]], %[[T_12]] : (tensor<?x?x?x?xf32>, tensor<?x1x1xf32>) -> tensor<?x?x?x?xf32>
// CHECK: %[[T_14:.*]] = torch_c.from_builtin_tensor %[[T_13]] : tensor<?x?x?x?xf32> -> !torch.vtensor<[?,?,?,?],f32>
// CHECK: return %[[T_14]] : !torch.vtensor<[?,?,?,?],f32>
func.func @torch.aten.convolution$bias(%arg0: !torch.vtensor<[?,?,?,?],f32>, %arg1: !torch.vtensor<[?,?,?,?],f32>, %arg2: !torch.vtensor<[?],f32>) -> !torch.vtensor<[?,?,?,?],f32> {
func.func @torch.aten.convolution$bias(%arg0: !torch.vtensor<[?,?,?,?],f32>, %arg1: !torch.vtensor<[?,?,3,3],f32>, %arg2: !torch.vtensor<[?],f32>) -> !torch.vtensor<[?,?,?,?],f32> {
%int2 = torch.constant.int 2
%int1 = torch.constant.int 1
%int4 = torch.constant.int 4
@ -342,6 +342,160 @@ func.func @torch.aten.convolution$bias(%arg0: !torch.vtensor<[?,?,?,?],f32>, %ar
%3 = torch.prim.ListConstruct %int3, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
%4 = torch.prim.ListConstruct : () -> !torch.list<int>
%false = torch.constant.bool false
%5 = torch.aten.convolution %arg0, %arg1, %arg2, %1, %2, %3, %false, %4, %int3 : !torch.vtensor<[?,?,?,?],f32>, !torch.vtensor<[?,?,?,?],f32>, !torch.vtensor<[?],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int -> !torch.vtensor<[?,?,?,?],f32>
%5 = torch.aten.convolution %arg0, %arg1, %arg2, %1, %2, %3, %false, %4, %int3 : !torch.vtensor<[?,?,?,?],f32>, !torch.vtensor<[?,?,3,3],f32>, !torch.vtensor<[?],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int -> !torch.vtensor<[?,?,?,?],f32>
return %5 : !torch.vtensor<[?,?,?,?],f32>
}
// -----
// CHECK-LABEL: func.func @torch.aten.convolution$transposed_basic(
// CHECK-SAME: %[[ARG_0:.*]]: !torch.vtensor<[1,2,7,7],f32>,
// CHECK-SAME: %[[ARG_1:.*]]: !torch.vtensor<[2,4,3,3],f32>) -> !torch.vtensor<[1,4,9,9],f32> {
// CHECK: %[[T_0:.*]] = torch_c.to_builtin_tensor %[[ARG_0]] : !torch.vtensor<[1,2,7,7],f32> -> tensor<1x2x7x7xf32>
// CHECK: %[[T_1:.*]] = torch_c.to_builtin_tensor %[[ARG_1]] : !torch.vtensor<[2,4,3,3],f32> -> tensor<2x4x3x3xf32>
// CHECK: %true = torch.constant.bool true
// CHECK: %none = torch.constant.none
// CHECK: %int0 = torch.constant.int 0
// CHECK: %int1 = torch.constant.int 1
// CHECK: %[[T_2:.*]] = torch_c.to_i64 %int1
// CHECK: %[[T_3:.*]] = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[T_4:.*]] = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[T_5:.*]] = "mhlo.reverse"(%[[T_1]]) {dimensions = dense<[2, 3]> : tensor<2xi64>} : (tensor<2x4x3x3xf32>) -> tensor<2x4x3x3xf32>
// CHECK: %[[T_6:.*]] = mhlo.convolution(%[[T_0]], %[[T_5]])
// CHECK{LITERAL}: dim_numbers = [b, f, 0, 1]x[i, o, 0, 1]->[b, f, 0, 1], window = {stride = [1, 1], pad = [[2, 2], [2, 2]], lhs_dilate = [1, 1], rhs_dilate = [1, 1]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor<1x2x7x7xf32>, tensor<2x4x3x3xf32>) -> tensor<1x4x9x9xf32>
// CHECK: %[[T_7:.*]] = torch_c.from_builtin_tensor %[[T_6]] : tensor<1x4x9x9xf32> -> !torch.vtensor<[1,4,9,9],f32>
// CHECK: return %[[T_7]] : !torch.vtensor<[1,4,9,9],f32>
func.func @torch.aten.convolution$transposed_basic(%arg0: !torch.vtensor<[1,2,7,7],f32>, %arg1: !torch.vtensor<[2,4,3,3],f32>) -> !torch.vtensor<[1,4,9,9],f32> {
%true = torch.constant.bool true
%none = torch.constant.none
%int0 = torch.constant.int 0
%int1 = torch.constant.int 1
%0 = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list<int>
%1 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
%2 = torch.aten.convolution %arg0, %arg1, %none, %1, %0, %1, %true, %0, %int1 : !torch.vtensor<[1,2,7,7],f32>, !torch.vtensor<[2,4,3,3],f32>, !torch.none, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int -> !torch.vtensor<[1,4,9,9],f32>
return %2 : !torch.vtensor<[1,4,9,9],f32>
}
// -----
// CHECK-LABEL: func.func @torch.aten.convolution$transposed_stride(
// CHECK-SAME: %[[ARG_0:.*]]: !torch.vtensor<[1,2,7,7],f32>,
// CHECK-SAME: %[[ARG_1:.*]]: !torch.vtensor<[2,4,3,3],f32>) -> !torch.vtensor<[1,4,15,15],f32> {
// CHECK: %[[T_0:.*]] = torch_c.to_builtin_tensor %[[ARG_0]] : !torch.vtensor<[1,2,7,7],f32> -> tensor<1x2x7x7xf32>
// CHECK: %[[T_1:.*]] = torch_c.to_builtin_tensor %[[ARG_1]] : !torch.vtensor<[2,4,3,3],f32> -> tensor<2x4x3x3xf32>
// CHECK: %true = torch.constant.bool true
// CHECK: %none = torch.constant.none
// CHECK: %int0 = torch.constant.int 0
// CHECK: %int1 = torch.constant.int 1
// CHECK: %[[T_2:.*]] = torch_c.to_i64 %int1
// CHECK: %int2 = torch.constant.int 2
// CHECK: %[[T_3:.*]] = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[T_4:.*]] = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[T_5:.*]] = torch.prim.ListConstruct %int2, %int2 : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[T_6:.*]] = "mhlo.reverse"(%1) {dimensions = dense<[2, 3]> : tensor<2xi64>} : (tensor<2x4x3x3xf32>) -> tensor<2x4x3x3xf32>
// CHECK: %[[T_7:.*]] = mhlo.convolution(%[[T_0]], %[[T_6]])
// CHECK{LITERAL}: dim_numbers = [b, f, 0, 1]x[i, o, 0, 1]->[b, f, 0, 1], window = {stride = [1, 1], pad = [[2, 2], [2, 2]], lhs_dilate = [2, 2], rhs_dilate = [1, 1]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor<1x2x7x7xf32>, tensor<2x4x3x3xf32>) -> tensor<1x4x15x15xf32>
// CHECK: %[[T_8:.*]] = torch_c.from_builtin_tensor %[[T_7]] : tensor<1x4x15x15xf32> -> !torch.vtensor<[1,4,15,15],f32>
// CHECK: return %[[T_8]] : !torch.vtensor<[1,4,15,15],f32>
func.func @torch.aten.convolution$transposed_stride(%arg0: !torch.vtensor<[1,2,7,7],f32>, %arg1: !torch.vtensor<[2,4,3,3],f32>) -> !torch.vtensor<[1,4,15,15],f32> {
%true = torch.constant.bool true
%none = torch.constant.none
%int0 = torch.constant.int 0
%int1 = torch.constant.int 1
%int2 = torch.constant.int 2
%0 = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list<int>
%1 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
%2 = torch.prim.ListConstruct %int2, %int2 : (!torch.int, !torch.int) -> !torch.list<int>
%3 = torch.aten.convolution %arg0, %arg1, %none, %2, %0, %1, %true, %0, %int1 : !torch.vtensor<[1,2,7,7],f32>, !torch.vtensor<[2,4,3,3],f32>, !torch.none, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int -> !torch.vtensor<[1,4,15,15],f32>
return %3 : !torch.vtensor<[1,4,15,15],f32>
}
// -----
// CHECK-LABEL: func.func @torch.aten.convolution$transposed_outputpadding(
// CHECK-SAME: %[[ARG_0:.*]]: !torch.vtensor<[1,2,7,7],f32>,
// CHECK-SAME: %[[ARG_1:.*]]: !torch.vtensor<[2,4,3,3],f32>) -> !torch.vtensor<[1,4,16,16],f32> {
// CHECK: %[[T_0:.*]] = torch_c.to_builtin_tensor %[[ARG_0]] : !torch.vtensor<[1,2,7,7],f32> -> tensor<1x2x7x7xf32>
// CHECK: %[[T_1:.*]] = torch_c.to_builtin_tensor %[[ARG_1]] : !torch.vtensor<[2,4,3,3],f32> -> tensor<2x4x3x3xf32>
// CHECK: %true = torch.constant.bool true
// CHECK: %none = torch.constant.none
// CHECK: %int0 = torch.constant.int 0
// CHECK: %int1 = torch.constant.int 1
// CHECK: %[[T_2:.*]] = torch_c.to_i64 %int1
// CHECK: %int2 = torch.constant.int 2
// CHECK: %[[T_3:.*]] = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[T_4:.*]] = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[T_5:.*]] = torch.prim.ListConstruct %int2, %int2 : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[T_6:.*]] = "mhlo.reverse"(%[[T_1]]) {dimensions = dense<[2, 3]> : tensor<2xi64>} : (tensor<2x4x3x3xf32>) -> tensor<2x4x3x3xf32>
// CHECK: %[[T_7:.*]] = mhlo.convolution(%[[T_0]], %[[T_6]])
// CHECK{LITERAL}: dim_numbers = [b, f, 0, 1]x[i, o, 0, 1]->[b, f, 0, 1], window = {stride = [1, 1], pad = [[2, 2], [2, 2]], lhs_dilate = [2, 2], rhs_dilate = [1, 1]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor<1x2x7x7xf32>, tensor<2x4x3x3xf32>) -> tensor<1x4x15x15xf32>
// CHECK: %[[T_8:.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
// CHECK: %[[T_9:.*]] = "mhlo.pad"(%[[T_7]], %[[T_8]]) {edge_padding_high = dense<[0, 0, 1, 1]> : vector<4xi64>, edge_padding_low = dense<0> : vector<4xi64>, interior_padding = dense<0> : vector<4xi64>} : (tensor<1x4x15x15xf32>, tensor<f32>) -> tensor<1x4x16x16xf32>
// CHECK: %[[T_10:.*]] = torch_c.from_builtin_tensor %[[T_9:.*]] : tensor<1x4x16x16xf32> -> !torch.vtensor<[1,4,16,16],f32>
// CHECK: return %[[T_10]] : !torch.vtensor<[1,4,16,16],f32>
func.func @torch.aten.convolution$transposed_outputpadding(%arg0: !torch.vtensor<[1,2,7,7],f32>, %arg1: !torch.vtensor<[2,4,3,3],f32>) -> !torch.vtensor<[1,4,16,16],f32> {
%true = torch.constant.bool true
%none = torch.constant.none
%int0 = torch.constant.int 0
%int1 = torch.constant.int 1
%int2 = torch.constant.int 2
%0 = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list<int>
%1 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
%2 = torch.prim.ListConstruct %int2, %int2 : (!torch.int, !torch.int) -> !torch.list<int>
%3 = torch.aten.convolution %arg0, %arg1, %none, %2, %0, %1, %true, %1, %int1 : !torch.vtensor<[1,2,7,7],f32>, !torch.vtensor<[2,4,3,3],f32>, !torch.none, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int -> !torch.vtensor<[1,4,16,16],f32>
return %3 : !torch.vtensor<[1,4,16,16],f32>
}
// -----
// CHECK-LABEL: func.func @torch.aten.convolution$transposed_groups(
// CHECK-SAME: %[[ARG_0:.*]]: !torch.vtensor<[1,2,7,7],f32>,
// CHECK-SAME: %[[ARG_1:.*]]: !torch.vtensor<[2,2,3,3],f32>) -> !torch.vtensor<[1,4,15,15],f32> {
// CHECK: %[[T_0:.*]] = torch_c.to_builtin_tensor %[[ARG_0]] : !torch.vtensor<[1,2,7,7],f32> -> tensor<1x2x7x7xf32>
// CHECK: %[[T_1:.*]] = torch_c.to_builtin_tensor %[[ARG_1]] : !torch.vtensor<[2,2,3,3],f32> -> tensor<2x2x3x3xf32>
// CHECK: %true = torch.constant.bool true
// CHECK: %none = torch.constant.none
// CHECK: %int0 = torch.constant.int 0
// CHECK: %int1 = torch.constant.int 1
// CHECK: %int2 = torch.constant.int 2
// CHECK: %[[T_2:.*]] = torch_c.to_i64 %int2
// CHECK: %[[T_3:.*]] = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[T_4:.*]] = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[T_5:.*]] = torch.prim.ListConstruct %int2, %int2 : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[T_6:.*]] = "mhlo.reverse"(%1) {dimensions = dense<[2, 3]> : tensor<2xi64>} : (tensor<2x2x3x3xf32>) -> tensor<2x2x3x3xf32>
// CHECK: %[[IDX_0:.*]] = arith.constant 0 : index
// CHECK: %[[T_7:.*]] = tensor.dim %[[T_6]], %[[IDX_0]] : tensor<2x2x3x3xf32>
// CHECK: %[[T_8:.*]] = arith.index_cast %[[T_7]] : index to i64
// CHECK: %[[IDX_1:.*]] = arith.constant 1 : index
// CHECK: %[[T_9:.*]] = tensor.dim %[[T_6]], %[[IDX_1]] : tensor<2x2x3x3xf32>
// CHECK: %[[T_10:.*]] = arith.index_cast %[[T_9]] : index to i64
// CHECK: %[[IDX_2:.*]] = arith.constant 2 : index
// CHECK: %[[T_11:.*]] = tensor.dim %[[T_6]], %[[IDX_2]] : tensor<2x2x3x3xf32>
// CHECK: %[[T_12:.*]] = arith.index_cast %[[T_11]] : index to i64
// CHECK: %[[IDX_3:.*]] = arith.constant 3 : index
// CHECK: %[[T_13:.*]] = tensor.dim %[[T_6]], %[[IDX_3]] : tensor<2x2x3x3xf32>
// CHECK: %[[T_14:.*]] = arith.index_cast %[[T_13]] : index to i64
// CHECK: %[[T_24:.*]] = arith.constant 2 : i64
// CHECK: %[[T_15:.*]] = arith.divsi %[[T_8]], %[[T_24]] : i64
// CHECK: %[[T_16:.*]] = arith.muli %[[T_10]], %[[T_24]] : i64
// CHECK: %[[T_17:.*]] = tensor.from_elements %[[T_24]], %[[T_15]], %[[T_10]], %[[T_12]], %[[T_14]] : tensor<5xi64>
// CHECK: %[[T_18:.*]] = "mhlo.dynamic_reshape"(%[[T_6]], %[[T_17]]) : (tensor<2x2x3x3xf32>, tensor<5xi64>) -> tensor<2x1x2x3x3xf32>
// CHECK: %[[T_19:.*]] = "mhlo.transpose"(%[[T_18]]) {permutation = dense<[1, 0, 2, 3, 4]> : tensor<5xi64>} : (tensor<2x1x2x3x3xf32>) -> tensor<1x2x2x3x3xf32>
// CHECK: %[[T_20:.*]] = tensor.from_elements %[[T_15]], %[[T_16]], %[[T_12]], %[[T_14]] : tensor<4xi64>
// CHECK: %[[T_21:.*]] = "mhlo.dynamic_reshape"(%[[T_19]], %[[T_20]]) : (tensor<1x2x2x3x3xf32>, tensor<4xi64>) -> tensor<1x4x3x3xf32>
// CHECK: %[[T_22:.*]] = mhlo.convolution(%[[T_0]], %[[T_21]])
// CHECK{LITERAL}: dim_numbers = [b, f, 0, 1]x[i, o, 0, 1]->[b, f, 0, 1], window = {stride = [1, 1], pad = [[2, 2], [2, 2]], lhs_dilate = [2, 2], rhs_dilate = [1, 1]} {batch_group_count = 1 : i64, feature_group_count = 2 : i64} : (tensor<1x2x7x7xf32>, tensor<1x4x3x3xf32>) -> tensor<1x4x15x15xf32>
// CHECK: %[[T_23:.*]] = torch_c.from_builtin_tensor %[[T_22]] : tensor<1x4x15x15xf32> -> !torch.vtensor<[1,4,15,15],f32>
// CHECK: return %[[T_23]] : !torch.vtensor<[1,4,15,15],f32>
func.func @torch.aten.convolution$transposed_groups(%arg0: !torch.vtensor<[1,2,7,7],f32>, %arg1: !torch.vtensor<[2,2,3,3],f32>) -> !torch.vtensor<[1,4,15,15],f32> {
%true = torch.constant.bool true
%none = torch.constant.none
%int0 = torch.constant.int 0
%int1 = torch.constant.int 1
%int2 = torch.constant.int 2
%0 = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list<int>
%1 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
%2 = torch.prim.ListConstruct %int2, %int2 : (!torch.int, !torch.int) -> !torch.list<int>
%3 = torch.aten.convolution %arg0, %arg1, %none, %2, %0, %1, %true, %0, %int2 : !torch.vtensor<[1,2,7,7],f32>, !torch.vtensor<[2,2,3,3],f32>, !torch.none, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int -> !torch.vtensor<[1,4,15,15],f32>
return %3 : !torch.vtensor<[1,4,15,15],f32>
}