[TOSA] Add reflection and replication pad lowering (#3874)

- Add Torch to TOSA legalization for the following ops (see the semantics sketch below):
  + aten.reflection_pad1d
  + aten.reflection_pad2d
  + aten.replication_pad2d
- Update xfail sets with new e2e results
- Add new LIT tests to basic.mlir

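Both reflection and replication padding follow the standard PyTorch semantics: reflection mirrors interior elements around the edge (so each pad width must be smaller than the corresponding input dimension), while replication repeats the edge value. A minimal PyTorch sketch of the behavior being legalized (illustration only, not part of this change):

import torch
import torch.nn.functional as F

x = torch.arange(1.0, 5.0).reshape(1, 1, 4)   # [[[1, 2, 3, 4]]]
# Reflection pad (left=3, right=1) mirrors interior elements, excluding the
# edge itself: [[[4, 3, 2, 1, 2, 3, 4, 3]]]
print(F.pad(x, (3, 1), mode="reflect"))
# Replication pad (left=2, right=1) repeats the edge value:
# [[[1, 1, 1, 2, 3, 4, 4]]]
print(F.pad(x, (2, 1), mode="replicate"))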

Change-Id: I1689d1778d8e472c3317aca1e2425ef8774a07fa

Signed-off-by: Justin Ngo <justin.ngo@arm.com>
Justin Ngo 2024-11-15 15:19:09 -08:00 committed by GitHub
3 changed files with 524 additions and 24 deletions


@@ -7194,6 +7194,432 @@ LogicalResult ConvertAtenOp<PrimsCollapseOp>::matchAndRewrite(
return success();
}
// Legalization for aten.reflection_pad1d
template <>
LogicalResult ConvertAtenOp<AtenReflectionPad1dOp>::matchAndRewrite(
AtenReflectionPad1dOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
auto self = adaptor.getSelf();
auto selfType = dyn_cast<TensorType>(self.getType());
if (!selfType)
return rewriter.notifyMatchFailure(op, "Only tensor types are supported");
auto selfShape = selfType.getShape();
auto selfRank = selfType.getRank();
auto selfElemTy = selfType.getElementType();
auto resultType =
dyn_cast<TensorType>(typeConverter->convertType(op.getType()));
SmallVector<int64_t, 2> paddingList;
if (!matchPattern(op.getPadding(), m_TorchListOfConstantInts(paddingList)))
return rewriter.notifyMatchFailure(
op, "Non-const padding lists are not supported");
int64_t paddingLeft = paddingList[0];
int64_t paddingRight = paddingList[1];
if (paddingLeft >= selfShape[selfRank - 1] ||
paddingRight >= selfShape[selfRank - 1])
return rewriter.notifyMatchFailure(
op, "Padding should be less than input boundary size");
// Identity case
if (paddingLeft == 0 && paddingRight == 0) {
rewriter.replaceOp(op, self);
return success();
}
SmallVector<Value> resultTensors;
// Use tosa.slice and tosa.reverse to get the reflection pads based on the
// padding size
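// For a left pad of width k, reflection excludes the edge element itself:
// slice elements [1, k] of the last dim and reverse them. The right pad
// mirrors the k elements just before the last one (slice starting at
// dimSize - k - 1), also reversed.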
if (paddingLeft > 0) {
SmallVector<int64_t> leftStartSlice(selfRank, 0);
SmallVector<int64_t> leftSizeSlice(selfShape);
leftStartSlice[selfRank - 1] = 1;
leftSizeSlice[selfRank - 1] = paddingLeft;
SmallVector<int64_t> leftPadShape(selfShape.begin(), selfShape.end() - 1);
leftPadShape.push_back(paddingLeft);
auto leftPadType = RankedTensorType::get(leftPadShape, selfElemTy);
auto leftPadSlice = rewriter.create<tosa::SliceOp>(
op->getLoc(), leftPadType, self,
rewriter.getDenseI64ArrayAttr(leftStartSlice),
rewriter.getDenseI64ArrayAttr(leftSizeSlice));
auto leftPad = rewriter.create<tosa::ReverseOp>(
op->getLoc(), leftPadType, leftPadSlice.getResult(),
static_cast<int32_t>(selfRank - 1));
resultTensors.push_back(leftPad.getResult());
}
resultTensors.push_back(self);
if (paddingRight > 0) {
SmallVector<int64_t> rightStartSlice(selfRank, 0);
SmallVector<int64_t> rightSizeSlice(selfShape);
rightStartSlice[selfRank - 1] = selfShape[selfRank - 1] - paddingRight - 1;
rightSizeSlice[selfRank - 1] = paddingRight;
SmallVector<int64_t> rightPadShape(selfShape.begin(), selfShape.end() - 1);
rightPadShape.push_back(paddingRight);
auto rightPadType = RankedTensorType::get(rightPadShape, selfElemTy);
auto rightPadSlice = rewriter.create<tosa::SliceOp>(
op->getLoc(), rightPadType, self,
rewriter.getDenseI64ArrayAttr(rightStartSlice),
rewriter.getDenseI64ArrayAttr(rightSizeSlice));
auto rightPad = rewriter.create<tosa::ReverseOp>(
op->getLoc(), rightPadType, rightPadSlice.getResult(),
static_cast<int32_t>(selfRank - 1));
resultTensors.push_back(rightPad.getResult());
}
auto result = tosa::CreateOpAndInfer<tosa::ConcatOp>(
rewriter, op->getLoc(), resultType, resultTensors, selfRank - 1);
rewriter.replaceOp(op, result);
return success();
}
// Legalization for aten.reflection_pad2d
template <>
LogicalResult ConvertAtenOp<AtenReflectionPad2dOp>::matchAndRewrite(
AtenReflectionPad2dOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
auto self = adaptor.getSelf();
auto selfType = dyn_cast<TensorType>(self.getType());
if (!selfType)
return rewriter.notifyMatchFailure(op, "Only tensor types are supported");
auto selfShape = selfType.getShape();
auto selfRank = selfType.getRank();
auto selfElemTy = selfType.getElementType();
auto resultType =
dyn_cast<TensorType>(typeConverter->convertType(op.getType()));
auto resultShape = resultType.getShape();
SmallVector<int64_t, 4> paddingList;
if (!matchPattern(op.getPadding(), m_TorchListOfConstantInts(paddingList)))
return rewriter.notifyMatchFailure(
op, "Non-const padding lists are not supported");
int64_t paddingLeft = paddingList[0];
int64_t paddingRight = paddingList[1];
int64_t paddingTop = paddingList[2];
int64_t paddingBottom = paddingList[3];
if (paddingLeft >= selfShape[selfRank - 1] ||
paddingRight >= selfShape[selfRank - 1] ||
paddingTop >= selfShape[selfRank - 2] ||
paddingBottom >= selfShape[selfRank - 2])
return rewriter.notifyMatchFailure(
op, "Padding must be less than the corresponding input dimension");
// Identity case
if (paddingLeft == 0 && paddingRight == 0 && paddingTop == 0 &&
paddingBottom == 0) {
rewriter.replaceOp(op, self);
return success();
}
// Use tosa.slice and tosa.reverse to get the reflection pads based on the
// padding size
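// The width (last dim) is padded first; the height (dim rank - 2) is then
// padded by slicing from the width-padded intermediate, so the reflected
// corner regions are produced correctly.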
SmallVector<Value> sideTensors;
if (paddingLeft > 0) {
SmallVector<int64_t> leftStartSlice(selfRank, 0);
SmallVector<int64_t> leftSizeSlice(selfShape);
leftStartSlice[selfRank - 1] = 1;
leftSizeSlice[selfRank - 1] = paddingLeft;
SmallVector<int64_t> leftPadShape(selfShape.begin(), selfShape.end() - 1);
leftPadShape.push_back(paddingLeft);
auto leftPadType = RankedTensorType::get(leftPadShape, selfElemTy);
auto leftPadSlice = rewriter.create<tosa::SliceOp>(
op->getLoc(), leftPadType, self,
rewriter.getDenseI64ArrayAttr(leftStartSlice),
rewriter.getDenseI64ArrayAttr(leftSizeSlice));
auto leftPad = rewriter.create<tosa::ReverseOp>(
op->getLoc(), leftPadType, leftPadSlice.getResult(),
static_cast<int32_t>(selfRank - 1));
sideTensors.push_back(leftPad.getResult());
}
sideTensors.push_back(self);
if (paddingRight > 0) {
SmallVector<int64_t> rightStartSlice(selfRank, 0);
SmallVector<int64_t> rightSizeSlice(selfShape);
rightStartSlice[selfRank - 1] = selfShape[selfRank - 1] - paddingRight - 1;
rightSizeSlice[selfRank - 1] = paddingRight;
SmallVector<int64_t> rightPadShape(selfShape.begin(), selfShape.end() - 1);
rightPadShape.push_back(paddingRight);
auto rightPadType = RankedTensorType::get(rightPadShape, selfElemTy);
auto rightPadSlice = rewriter.create<tosa::SliceOp>(
op->getLoc(), rightPadType, self,
rewriter.getDenseI64ArrayAttr(rightStartSlice),
rewriter.getDenseI64ArrayAttr(rightSizeSlice));
auto rightPad = rewriter.create<tosa::ReverseOp>(
op->getLoc(), rightPadType, rightPadSlice.getResult(),
static_cast<int32_t>(selfRank - 1));
sideTensors.push_back(rightPad.getResult());
}
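// Concatenate left pad, input, and right pad along the width. The
// intermediate tensor has the final output width but the original height.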
SmallVector<int64_t> selfSidePaddedShape(selfShape.begin(),
selfShape.end() - 1);
selfSidePaddedShape.push_back(resultShape.back());
auto selfSidePadded = tosa::CreateOpAndInfer<tosa::ConcatOp>(
rewriter, op->getLoc(),
RankedTensorType::get(selfSidePaddedShape, selfElemTy), sideTensors,
selfRank - 1);
SmallVector<Value> resultTensors;
if (paddingTop > 0) {
SmallVector<int64_t> topStartSlice(selfRank, 0);
SmallVector<int64_t> topSizeSlice(selfShape.begin(), selfShape.end() - 1);
topSizeSlice.push_back(resultShape.back());
topStartSlice[selfRank - 2] = 1;
topSizeSlice[selfRank - 2] = paddingTop;
SmallVector<int64_t> topPadShape(selfShape.begin(), selfShape.end() - 2);
topPadShape.push_back(paddingTop);
topPadShape.push_back(resultShape.back());
auto topPadType = RankedTensorType::get(topPadShape, selfElemTy);
auto topPadSlice = rewriter.create<tosa::SliceOp>(
op->getLoc(), topPadType, selfSidePadded,
rewriter.getDenseI64ArrayAttr(topStartSlice),
rewriter.getDenseI64ArrayAttr(topSizeSlice));
auto topPad = rewriter.create<tosa::ReverseOp>(
op->getLoc(), topPadType, topPadSlice.getResult(),
static_cast<int32_t>(selfRank - 2));
resultTensors.push_back(topPad.getResult());
}
resultTensors.push_back(selfSidePadded.getResult());
if (paddingBottom > 0) {
SmallVector<int64_t> bottomStartSlice(selfRank, 0);
SmallVector<int64_t> bottomSizeSlice(selfShape.begin(),
selfShape.end() - 1);
bottomSizeSlice.push_back(resultShape.back());
bottomStartSlice[selfRank - 2] =
selfShape[selfRank - 2] - paddingBottom - 1;
bottomSizeSlice[selfRank - 2] = paddingBottom;
SmallVector<int64_t> bottomPadShape(selfShape.begin(), selfShape.end() - 2);
bottomPadShape.push_back(paddingBottom);
bottomPadShape.push_back(resultShape.back());
auto bottomPadType = RankedTensorType::get(bottomPadShape, selfElemTy);
auto bottomPadSlice = rewriter.create<tosa::SliceOp>(
op->getLoc(), bottomPadType, selfSidePadded,
rewriter.getDenseI64ArrayAttr(bottomStartSlice),
rewriter.getDenseI64ArrayAttr(bottomSizeSlice));
auto bottomPad = rewriter.create<tosa::ReverseOp>(
op->getLoc(), bottomPadType, bottomPadSlice.getResult(),
static_cast<int32_t>(selfRank - 2));
resultTensors.push_back(bottomPad.getResult());
}
auto result = tosa::CreateOpAndInfer<tosa::ConcatOp>(
rewriter, op->getLoc(), resultType, resultTensors, selfRank - 2);
rewriter.replaceOp(op, result);
return success();
}
// Legalization for aten.replication_pad2d
template <>
LogicalResult ConvertAtenOp<AtenReplicationPad2dOp>::matchAndRewrite(
AtenReplicationPad2dOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
auto self = adaptor.getSelf();
auto selfType = dyn_cast<TensorType>(self.getType());
if (!selfType)
return rewriter.notifyMatchFailure(op, "Only tensor types are supported");
auto selfShape = selfType.getShape();
auto selfRank = selfType.getRank();
auto selfElemTy = selfType.getElementType();
auto resultType =
dyn_cast<TensorType>(typeConverter->convertType(op.getType()));
auto resultShape = resultType.getShape();
SmallVector<int64_t, 4> paddingList;
if (!matchPattern(op.getPadding(), m_TorchListOfConstantInts(paddingList)))
return rewriter.notifyMatchFailure(
op, "Non-const padding lists are not supported");
int64_t paddingLeft = paddingList[0];
int64_t paddingRight = paddingList[1];
int64_t paddingTop = paddingList[2];
int64_t paddingBottom = paddingList[3];
// Identity case
if (paddingLeft == 0 && paddingRight == 0 && paddingTop == 0 &&
paddingBottom == 0) {
rewriter.replaceOp(op, self);
return success();
}
// Use tosa.slice to get the replication pads based on the padding size
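// Each pad is a single-element slice of the edge row/column, concatenated
// paddingN times to replicate the edge value.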
SmallVector<Value> sideTensors;
if (paddingLeft > 0) {
SmallVector<int64_t> leftStartSlice(selfRank, 0);
SmallVector<int64_t> leftSizeSlice(selfShape);
leftStartSlice[selfRank - 1] = 0;
leftSizeSlice[selfRank - 1] = 1;
SmallVector<int64_t> leftPadSliceShape(selfShape.begin(),
selfShape.end() - 1);
leftPadSliceShape.push_back(1);
auto leftPadSliceType =
RankedTensorType::get(leftPadSliceShape, selfElemTy);
auto leftPadSlice = rewriter.create<tosa::SliceOp>(
op->getLoc(), leftPadSliceType, self,
rewriter.getDenseI64ArrayAttr(leftStartSlice),
rewriter.getDenseI64ArrayAttr(leftSizeSlice));
for (int64_t i = 0; i < paddingLeft; i++)
sideTensors.push_back(leftPadSlice.getResult());
}
sideTensors.push_back(self);
if (paddingRight > 0) {
SmallVector<int64_t> rightStartSlice(selfRank, 0);
SmallVector<int64_t> rightSizeSlice(selfShape);
rightStartSlice[selfRank - 1] = selfShape[selfRank - 1] - 1;
rightSizeSlice[selfRank - 1] = 1;
SmallVector<int64_t> rightPadSliceShape(selfShape.begin(),
selfShape.end() - 1);
rightPadSliceShape.push_back(1);
auto rightPadSliceType =
RankedTensorType::get(rightPadSliceShape, selfElemTy);
auto rightPadSlice = rewriter.create<tosa::SliceOp>(
op->getLoc(), rightPadSliceType, self,
rewriter.getDenseI64ArrayAttr(rightStartSlice),
rewriter.getDenseI64ArrayAttr(rightSizeSlice));
for (int64_t i = 0; i < paddingRight; i++)
sideTensors.push_back(rightPadSlice.getResult());
}
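// Concatenate the repeated edge columns with the input along the width; the
// intermediate tensor already has the final output width.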
SmallVector<int64_t> selfSidePaddedShape(selfShape.begin(),
selfShape.end() - 1);
selfSidePaddedShape.push_back(resultShape.back());
auto selfSidePadded = tosa::CreateOpAndInfer<tosa::ConcatOp>(
rewriter, op->getLoc(),
RankedTensorType::get(selfSidePaddedShape, selfElemTy), sideTensors,
selfRank - 1);
SmallVector<Value> resultTensors;
if (paddingTop > 0) {
SmallVector<int64_t> topStartSlice(selfRank, 0);
SmallVector<int64_t> topSizeSlice(selfShape.begin(), selfShape.end() - 1);
topSizeSlice.push_back(resultShape.back());
topStartSlice[selfRank - 2] = 0;
topSizeSlice[selfRank - 2] = 1;
SmallVector<int64_t> topPadSliceShape(selfShape.begin(),
selfShape.end() - 2);
topPadSliceShape.push_back(1);
topPadSliceShape.push_back(resultShape.back());
auto topPadSliceType = RankedTensorType::get(topPadSliceShape, selfElemTy);
auto topPadSlice = rewriter.create<tosa::SliceOp>(
op->getLoc(), topPadSliceType, selfSidePadded,
rewriter.getDenseI64ArrayAttr(topStartSlice),
rewriter.getDenseI64ArrayAttr(topSizeSlice));
for (int64_t i = 0; i < paddingTop; i++)
resultTensors.push_back(topPadSlice.getResult());
}
resultTensors.push_back(selfSidePadded.getResult());
if (paddingBottom > 0) {
SmallVector<int64_t> bottomStartSlice(selfRank, 0);
SmallVector<int64_t> bottomSizeSlice(selfShape.begin(),
selfShape.end() - 1);
bottomSizeSlice.push_back(resultShape.back());
bottomStartSlice[selfRank - 2] = selfShape[selfRank - 2] - 1;
bottomSizeSlice[selfRank - 2] = 1;
SmallVector<int64_t> bottomPadSliceShape(selfShape.begin(),
selfShape.end() - 2);
bottomPadSliceShape.push_back(1);
bottomPadSliceShape.push_back(resultShape.back());
auto bottomPadSliceType =
RankedTensorType::get(bottomPadSliceShape, selfElemTy);
auto bottomPadSlice = rewriter.create<tosa::SliceOp>(
op->getLoc(), bottomPadSliceType, selfSidePadded,
rewriter.getDenseI64ArrayAttr(bottomStartSlice),
rewriter.getDenseI64ArrayAttr(bottomSizeSlice));
for (int64_t i = 0; i < paddingBottom; i++)
resultTensors.push_back(bottomPadSlice.getResult());
}
auto result = tosa::CreateOpAndInfer<tosa::ConcatOp>(
rewriter, op->getLoc(), resultType, resultTensors, selfRank - 2);
rewriter.replaceOp(op, result);
return success();
}
} // namespace
// -----------------------------------------------------------------------------
@@ -7521,6 +7947,9 @@ public:
INSERT_ATENOP_PATTERN(AtenAsStridedOp);
INSERT_ATENOP_PATTERN(AtenClampTensorOp);
INSERT_ATENOP_PATTERN(PrimsCollapseOp);
INSERT_ATENOP_PATTERN(AtenReflectionPad1dOp);
INSERT_ATENOP_PATTERN(AtenReflectionPad2dOp);
INSERT_ATENOP_PATTERN(AtenReplicationPad2dOp);
#undef INSERT_ATENOP_PATTERN
#define INSERT_CLONE_ATENOP_PATTERN(AtenOp) \


@@ -1736,6 +1736,20 @@ FX_IMPORTER_TOSA_CRASHING_SET = {
# Write the TOSA set as a "passing" set as it is very early in development
# and very few tests work yet.
TOSA_PASS_SET = {
"ReflectionPad1dModule2dInput_Right",
"ReflectionPad1dModule2dInput_basic",
"ReflectionPad1dModule3dInput_Left",
"ReflectionPad1dModule3dInput_basic",
"ReflectionPad2dModule_Bottom",
"ReflectionPad2dModule_Left",
"ReflectionPad2dModule_Right",
"ReflectionPad2dModule_Top",
"ReflectionPad2dModule_basic",
"ReplicationPad2dModule_basic",
"ReplicationPad2dModule_bottom0",
"ReplicationPad2dModule_left0",
"ReplicationPad2dModule_right0",
"ReplicationPad2dModule_top0",
"AdaptiveAvgPool1dNonUnitOutputSizeStaticModule_basic",
"AdaptiveAvgPool1dUnitOutputSizeStaticModule_basic",
"AdaptiveAvgPool2dFixedKernelStrideSizeStaticModule_basic",
@@ -2439,6 +2453,7 @@ MAKE_FX_TOSA_PASS_SET = (
TOSA_PASS_SET
| {
### Tests additionally passing in make_fx_tosa
"AdaptiveAvgPool1dStaticEvenMultiple_basic",
"IsInfiniteModule_basic",
"AdaptiveAvgPool2dFixedKernelStrideSizeStaticModule_basic",
"AdaptiveAvgPool2dUnitOutputSizeStaticModule_basic",
@@ -4163,7 +4178,6 @@ ONNX_TOSA_XFAIL_SET = {
"ChunkListUnpackDynamic_Module_basic",
"ChunkListUnpackUnevenDynamic_Module_basic",
"ChunkListUnpackUneven_Module_basic",
"ChunkListUnpack_Module_basic",
"CollapseAllDimensionsModule_basic",
"CollapseFullDynamicModule_basic",
"CollapsePartialDynamicModule_basic",
@@ -4538,7 +4552,6 @@ ONNX_TOSA_XFAIL_SET = {
"MeanDimNoneDimModule_basic",
"MeanDtypeModule_basic",
"MeanDynamicSizesModule_basic",
"MeanModule_basic",
"Mlp1LayerModule_basic",
"Mlp2LayerModuleNoBias_basic",
"Mlp2LayerModule_basic",
@@ -4695,27 +4708,9 @@ ONNX_TOSA_XFAIL_SET = {
"ReduceSumDimIntListDtypeFloatModule_basic",
"ReduceSumDimIntListDtypeIntModule_basic",
"ReduceSumDimIntListElementTypeBoolModule_basic",
"ReduceSumDimIntListEmptyDimModule_basic",
"ReduceSumDtypeFloatModule_basic",
"ReduceSumDtypeIntModule_basic",
"ReduceSumElementTypeBoolModule_basic",
"ReduceSumFloatModule_basic",
"ReduceSumSignedIntModule_basic",
"ReduceSumUnsignedIntModule_basic",
"ReflectionPad1dModule2dInput_Right",
"ReflectionPad1dModule2dInput_basic",
"ReflectionPad1dModule3dInput_Left",
"ReflectionPad1dModule3dInput_basic",
"ReflectionPad2dModule_Bottom",
"ReflectionPad2dModule_Left",
"ReflectionPad2dModule_Right",
"ReflectionPad2dModule_Top",
"ReflectionPad2dModule_basic",
"ReplicationPad2dModule_basic",
"ReplicationPad2dModule_bottom0",
"ReplicationPad2dModule_left0",
"ReplicationPad2dModule_right0",
"ReplicationPad2dModule_top0",
"ResNet18Module_basic",
"ReshapeAliasCollapseModule_basic",
"ReshapeAliasExpandModule_basic",
@@ -4878,10 +4873,6 @@ ONNX_TOSA_XFAIL_SET = {
"TypePromotionDifferentCategoryModule_basic",
"TypePromotionSameCategoryDifferentWidthModule_basic",
"TypePromotionZeroRankHigherCategoryModule_basic",
"UnflattenIntNegativeOneDimStaticModule_basic",
"UnflattenIntNegativeOneSizeStaticModule_basic",
"UnflattenIntStaticModule_basic",
"UnflattenStaticModule_basic",
"UniformModule_basic",
"UniformNoCorrelationModule_basic",
"UniformStaticShapeModule_basic",


@@ -2439,3 +2439,83 @@ func.func @torch.aten.avg_pool1d.count_include_pad_unsupported_value(%arg0: !torch
%3 = torch.aten.avg_pool1d %arg0, %0, %1, %2, %false, %count_include_pad : !torch.vtensor<[1,512,10],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.bool -> !torch.vtensor<[1,512,10],f32>
return %3 : !torch.vtensor<[1,512,10],f32>
}
// -----
// CHECK-LABEL: func.func @torch.aten.reflection_pad1d$basic(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[1,2,4],f32>) -> !torch.vtensor<[1,2,8],f32> {
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[1,2,4],f32> -> tensor<1x2x4xf32>
// CHECK: %[[VAL_2:.*]] = torch.constant.int 3
// CHECK: %[[VAL_3:.*]] = torch.constant.int 1
// CHECK: %[[VAL_4:.*]] = torch.prim.ListConstruct %[[VAL_2]], %[[VAL_3]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[VAL_5:.*]] = tosa.slice %[[VAL_1]] {size = array<i64: 1, 2, 3>, start = array<i64: 0, 0, 1>} : (tensor<1x2x4xf32>) -> tensor<1x2x3xf32>
// CHECK: %[[VAL_6:.*]] = tosa.reverse %[[VAL_5]] {axis = 2 : i32} : (tensor<1x2x3xf32>) -> tensor<1x2x3xf32>
// CHECK: %[[VAL_7:.*]] = tosa.slice %[[VAL_1]] {size = array<i64: 1, 2, 1>, start = array<i64: 0, 0, 2>} : (tensor<1x2x4xf32>) -> tensor<1x2x1xf32>
// CHECK: %[[VAL_8:.*]] = tosa.reverse %[[VAL_7]] {axis = 2 : i32} : (tensor<1x2x1xf32>) -> tensor<1x2x1xf32>
// CHECK: %[[VAL_9:.*]] = tosa.concat %[[VAL_6]], %[[VAL_1]], %[[VAL_8]] {axis = 2 : i32} : (tensor<1x2x3xf32>, tensor<1x2x4xf32>, tensor<1x2x1xf32>) -> tensor<1x2x8xf32>
// CHECK: %[[VAL_10:.*]] = torch_c.from_builtin_tensor %[[VAL_9]] : tensor<1x2x8xf32> -> !torch.vtensor<[1,2,8],f32>
// CHECK: return %[[VAL_10]] : !torch.vtensor<[1,2,8],f32>
// CHECK: }
func.func @torch.aten.reflection_pad1d$basic(%arg0: !torch.vtensor<[1,2,4],f32>) -> !torch.vtensor<[1,2,8],f32> {
%int3 = torch.constant.int 3
%int1 = torch.constant.int 1
%0 = torch.prim.ListConstruct %int3, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
%1 = torch.aten.reflection_pad1d %arg0, %0 : !torch.vtensor<[1,2,4],f32>, !torch.list<int> -> !torch.vtensor<[1,2,8],f32>
return %1 : !torch.vtensor<[1,2,8],f32>
}
// -----
// CHECK-LABEL: func.func @torch.aten.reflection_pad2d$basic(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[1,20,20],f32>) -> !torch.vtensor<[1,40,40],f32> {
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[1,20,20],f32> -> tensor<1x20x20xf32>
// CHECK: %[[VAL_2:.*]] = torch.constant.int 10
// CHECK: %[[VAL_3:.*]] = torch.prim.ListConstruct %[[VAL_2]], %[[VAL_2]], %[[VAL_2]], %[[VAL_2]] : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[VAL_4:.*]] = tosa.slice %[[VAL_1]] {size = array<i64: 1, 20, 10>, start = array<i64: 0, 0, 1>} : (tensor<1x20x20xf32>) -> tensor<1x20x10xf32>
// CHECK: %[[VAL_5:.*]] = tosa.reverse %[[VAL_4]] {axis = 2 : i32} : (tensor<1x20x10xf32>) -> tensor<1x20x10xf32>
// CHECK: %[[VAL_6:.*]] = tosa.slice %[[VAL_1]] {size = array<i64: 1, 20, 10>, start = array<i64: 0, 0, 9>} : (tensor<1x20x20xf32>) -> tensor<1x20x10xf32>
// CHECK: %[[VAL_7:.*]] = tosa.reverse %[[VAL_6]] {axis = 2 : i32} : (tensor<1x20x10xf32>) -> tensor<1x20x10xf32>
// CHECK: %[[VAL_8:.*]] = tosa.concat %[[VAL_5]], %[[VAL_1]], %[[VAL_7]] {axis = 2 : i32} : (tensor<1x20x10xf32>, tensor<1x20x20xf32>, tensor<1x20x10xf32>) -> tensor<1x20x40xf32>
// CHECK: %[[VAL_9:.*]] = tosa.slice %[[VAL_8]] {size = array<i64: 1, 10, 40>, start = array<i64: 0, 1, 0>} : (tensor<1x20x40xf32>) -> tensor<1x10x40xf32>
// CHECK: %[[VAL_10:.*]] = tosa.reverse %[[VAL_9]] {axis = 1 : i32} : (tensor<1x10x40xf32>) -> tensor<1x10x40xf32>
// CHECK: %[[VAL_11:.*]] = tosa.slice %[[VAL_8]] {size = array<i64: 1, 10, 40>, start = array<i64: 0, 9, 0>} : (tensor<1x20x40xf32>) -> tensor<1x10x40xf32>
// CHECK: %[[VAL_12:.*]] = tosa.reverse %[[VAL_11]] {axis = 1 : i32} : (tensor<1x10x40xf32>) -> tensor<1x10x40xf32>
// CHECK: %[[VAL_13:.*]] = tosa.concat %[[VAL_10]], %[[VAL_8]], %[[VAL_12]] {axis = 1 : i32} : (tensor<1x10x40xf32>, tensor<1x20x40xf32>, tensor<1x10x40xf32>) -> tensor<1x40x40xf32>
// CHECK: %[[VAL_14:.*]] = torch_c.from_builtin_tensor %[[VAL_13]] : tensor<1x40x40xf32> -> !torch.vtensor<[1,40,40],f32>
// CHECK: return %[[VAL_14]] : !torch.vtensor<[1,40,40],f32>
// CHECK: }
func.func @torch.aten.reflection_pad2d$basic(%arg0: !torch.vtensor<[1,20,20],f32>) -> !torch.vtensor<[1,40,40],f32> {
%int10 = torch.constant.int 10
%0 = torch.prim.ListConstruct %int10, %int10, %int10, %int10 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
%1 = torch.aten.reflection_pad2d %arg0, %0 : !torch.vtensor<[1,20,20],f32>, !torch.list<int> -> !torch.vtensor<[1,40,40],f32>
return %1 : !torch.vtensor<[1,40,40],f32>
}
// -----
// CHECK-LABEL: func.func @torch.aten.replication_pad2d$basic(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[1,1,3,3],f32>) -> !torch.vtensor<[1,1,10,6],f32> {
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[1,1,3,3],f32> -> tensor<1x1x3x3xf32>
// CHECK: %[[VAL_2:.*]] = torch.constant.int 1
// CHECK: %[[VAL_3:.*]] = torch.constant.int 2
// CHECK: %[[VAL_4:.*]] = torch.constant.int 3
// CHECK: %[[VAL_5:.*]] = torch.constant.int 4
// CHECK: %[[VAL_6:.*]] = torch.prim.ListConstruct %[[VAL_2]], %[[VAL_3]], %[[VAL_4]], %[[VAL_5]] : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[VAL_7:.*]] = tosa.slice %[[VAL_1]] {size = array<i64: 1, 1, 3, 1>, start = array<i64: 0, 0, 0, 0>} : (tensor<1x1x3x3xf32>) -> tensor<1x1x3x1xf32>
// CHECK: %[[VAL_8:.*]] = tosa.slice %[[VAL_1]] {size = array<i64: 1, 1, 3, 1>, start = array<i64: 0, 0, 0, 2>} : (tensor<1x1x3x3xf32>) -> tensor<1x1x3x1xf32>
// CHECK: %[[VAL_9:.*]] = tosa.concat %[[VAL_7]], %[[VAL_1]], %[[VAL_8]], %[[VAL_8]] {axis = 3 : i32} : (tensor<1x1x3x1xf32>, tensor<1x1x3x3xf32>, tensor<1x1x3x1xf32>, tensor<1x1x3x1xf32>) -> tensor<1x1x3x6xf32>
// CHECK: %[[VAL_10:.*]] = tosa.slice %[[VAL_9]] {size = array<i64: 1, 1, 1, 6>, start = array<i64: 0, 0, 0, 0>} : (tensor<1x1x3x6xf32>) -> tensor<1x1x1x6xf32>
// CHECK: %[[VAL_11:.*]] = tosa.slice %[[VAL_9]] {size = array<i64: 1, 1, 1, 6>, start = array<i64: 0, 0, 2, 0>} : (tensor<1x1x3x6xf32>) -> tensor<1x1x1x6xf32>
// CHECK: %[[VAL_12:.*]] = tosa.concat %[[VAL_10]], %[[VAL_10]], %[[VAL_10]], %[[VAL_9]], %[[VAL_11]], %[[VAL_11]], %[[VAL_11]], %[[VAL_11]] {axis = 2 : i32} : (tensor<1x1x1x6xf32>, tensor<1x1x1x6xf32>, tensor<1x1x1x6xf32>, tensor<1x1x3x6xf32>, tensor<1x1x1x6xf32>, tensor<1x1x1x6xf32>, tensor<1x1x1x6xf32>, tensor<1x1x1x6xf32>) -> tensor<1x1x10x6xf32>
// CHECK: %[[VAL_13:.*]] = torch_c.from_builtin_tensor %[[VAL_12]] : tensor<1x1x10x6xf32> -> !torch.vtensor<[1,1,10,6],f32>
// CHECK: return %[[VAL_13]] : !torch.vtensor<[1,1,10,6],f32>
// CHECK: }
func.func @torch.aten.replication_pad2d$basic(%arg0: !torch.vtensor<[1,1,3,3],f32>) -> !torch.vtensor<[1,1,10,6],f32> {
%int1 = torch.constant.int 1
%int2 = torch.constant.int 2
%int3 = torch.constant.int 3
%int4 = torch.constant.int 4
%0 = torch.prim.ListConstruct %int1, %int2, %int3, %int4 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
%1 = torch.aten.replication_pad2d %arg0, %0 : !torch.vtensor<[1,1,3,3],f32>, !torch.list<int> -> !torch.vtensor<[1,1,10,6],f32>
return %1 : !torch.vtensor<[1,1,10,6],f32>
}