mirror of https://github.com/llvm/torch-mlir
[TOSA] Add reflection and replication pad lowering (#3874)
- Add Torch to TOSA legalization for the following ops: + aten.reflection_pad1d + aten.reflection_pad2d + aten.replication_pad2d - Update xfail sets with new e2e results - Add new LIT tests to basic.mlir Change-Id: I1689d1778d8e472c3317aca1e2425ef8774a07fa Signed-off-by: Justin Ngo <justin.ngo@arm.com>pull/3732/merge
parent
0a607a410d
commit
95f77817b9
|
@ -7194,6 +7194,432 @@ LogicalResult ConvertAtenOp<PrimsCollapseOp>::matchAndRewrite(
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Legalization for aten.reflection_pad1d
|
||||||
|
template <>
|
||||||
|
LogicalResult ConvertAtenOp<AtenReflectionPad1dOp>::matchAndRewrite(
|
||||||
|
AtenReflectionPad1dOp op, OpAdaptor adaptor,
|
||||||
|
ConversionPatternRewriter &rewriter) const {
|
||||||
|
auto self = adaptor.getSelf();
|
||||||
|
|
||||||
|
auto selfType = dyn_cast<TensorType>(self.getType());
|
||||||
|
if (!selfType)
|
||||||
|
return rewriter.notifyMatchFailure(op, "Only tensor types are supported");
|
||||||
|
|
||||||
|
auto selfShape = selfType.getShape();
|
||||||
|
auto selfRank = selfType.getRank();
|
||||||
|
auto selfElemTy = selfType.getElementType();
|
||||||
|
|
||||||
|
auto resultType =
|
||||||
|
dyn_cast<TensorType>(typeConverter->convertType(op.getType()));
|
||||||
|
|
||||||
|
SmallVector<int64_t, 2> paddingList;
|
||||||
|
if (!matchPattern(op.getPadding(), m_TorchListOfConstantInts(paddingList)))
|
||||||
|
return rewriter.notifyMatchFailure(
|
||||||
|
op, "Non-const padding lists are not supported");
|
||||||
|
|
||||||
|
int64_t paddingLeft = paddingList[0];
|
||||||
|
int64_t paddingRight = paddingList[1];
|
||||||
|
|
||||||
|
if (paddingLeft >= selfShape[selfRank - 1] ||
|
||||||
|
paddingRight >= selfShape[selfRank - 1])
|
||||||
|
return rewriter.notifyMatchFailure(
|
||||||
|
op, "Padding should be less than input boundary size");
|
||||||
|
|
||||||
|
// Identity case
|
||||||
|
if (paddingLeft == 0 && paddingRight == 0) {
|
||||||
|
rewriter.replaceOp(op, self);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
SmallVector<Value> resultTensors;
|
||||||
|
|
||||||
|
// Use tosa.slice and tosa.reverse to get the reflection pads based on the
|
||||||
|
// padding size
|
||||||
|
if (paddingLeft > 0) {
|
||||||
|
SmallVector<int64_t> leftStartSlice(selfRank, 0);
|
||||||
|
SmallVector<int64_t> leftSizeSlice(selfShape);
|
||||||
|
|
||||||
|
leftStartSlice[selfRank - 1] = 1;
|
||||||
|
leftSizeSlice[selfRank - 1] = paddingLeft;
|
||||||
|
|
||||||
|
SmallVector<int64_t> leftPadShape(selfShape.begin(), selfShape.end() - 1);
|
||||||
|
leftPadShape.push_back(paddingLeft);
|
||||||
|
|
||||||
|
auto leftPadType = RankedTensorType::get(leftPadShape, selfElemTy);
|
||||||
|
|
||||||
|
auto leftPadSlice = rewriter.create<tosa::SliceOp>(
|
||||||
|
op->getLoc(), leftPadType, self,
|
||||||
|
rewriter.getDenseI64ArrayAttr(leftStartSlice),
|
||||||
|
rewriter.getDenseI64ArrayAttr(leftSizeSlice));
|
||||||
|
|
||||||
|
auto leftPad = rewriter.create<tosa::ReverseOp>(
|
||||||
|
op->getLoc(), leftPadType, leftPadSlice.getResult(),
|
||||||
|
static_cast<int32_t>(selfRank - 1));
|
||||||
|
|
||||||
|
resultTensors.push_back(leftPad.getResult());
|
||||||
|
}
|
||||||
|
|
||||||
|
resultTensors.push_back(self);
|
||||||
|
|
||||||
|
if (paddingRight > 0) {
|
||||||
|
SmallVector<int64_t> rightStartSlice(selfRank, 0);
|
||||||
|
SmallVector<int64_t> rightSizeSlice(selfShape);
|
||||||
|
|
||||||
|
rightStartSlice[selfRank - 1] = selfShape[selfRank - 1] - paddingRight - 1;
|
||||||
|
rightSizeSlice[selfRank - 1] = paddingRight;
|
||||||
|
|
||||||
|
SmallVector<int64_t> rightPadShape(selfShape.begin(), selfShape.end() - 1);
|
||||||
|
rightPadShape.push_back(paddingRight);
|
||||||
|
|
||||||
|
auto rightPadType = RankedTensorType::get(rightPadShape, selfElemTy);
|
||||||
|
|
||||||
|
auto rightPadSlice = rewriter.create<tosa::SliceOp>(
|
||||||
|
op->getLoc(), rightPadType, self,
|
||||||
|
rewriter.getDenseI64ArrayAttr(rightStartSlice),
|
||||||
|
rewriter.getDenseI64ArrayAttr(rightSizeSlice));
|
||||||
|
|
||||||
|
auto rightPad = rewriter.create<tosa::ReverseOp>(
|
||||||
|
op->getLoc(), rightPadType, rightPadSlice.getResult(),
|
||||||
|
static_cast<int32_t>(selfRank - 1));
|
||||||
|
|
||||||
|
resultTensors.push_back(rightPad.getResult());
|
||||||
|
}
|
||||||
|
|
||||||
|
auto result = tosa::CreateOpAndInfer<tosa::ConcatOp>(
|
||||||
|
rewriter, op->getLoc(), resultType, resultTensors, selfRank - 1);
|
||||||
|
|
||||||
|
rewriter.replaceOp(op, result);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Legalization for aten.reflection_pad2d
|
||||||
|
template <>
|
||||||
|
LogicalResult ConvertAtenOp<AtenReflectionPad2dOp>::matchAndRewrite(
|
||||||
|
AtenReflectionPad2dOp op, OpAdaptor adaptor,
|
||||||
|
ConversionPatternRewriter &rewriter) const {
|
||||||
|
auto self = adaptor.getSelf();
|
||||||
|
|
||||||
|
auto selfType = dyn_cast<TensorType>(self.getType());
|
||||||
|
if (!selfType)
|
||||||
|
return rewriter.notifyMatchFailure(op, "Only tensor types are supported");
|
||||||
|
|
||||||
|
auto selfShape = selfType.getShape();
|
||||||
|
auto selfRank = selfType.getRank();
|
||||||
|
auto selfElemTy = selfType.getElementType();
|
||||||
|
|
||||||
|
auto resultType =
|
||||||
|
dyn_cast<TensorType>(typeConverter->convertType(op.getType()));
|
||||||
|
auto resultShape = resultType.getShape();
|
||||||
|
|
||||||
|
SmallVector<int64_t, 4> paddingList;
|
||||||
|
if (!matchPattern(op.getPadding(), m_TorchListOfConstantInts(paddingList)))
|
||||||
|
return rewriter.notifyMatchFailure(
|
||||||
|
op, "Non-const padding lists are not supported");
|
||||||
|
|
||||||
|
int64_t paddingLeft = paddingList[0];
|
||||||
|
int64_t paddingRight = paddingList[1];
|
||||||
|
int64_t paddingTop = paddingList[2];
|
||||||
|
int64_t paddingBottom = paddingList[3];
|
||||||
|
|
||||||
|
if (paddingLeft >= selfShape[selfRank - 1] ||
|
||||||
|
paddingRight >= selfShape[selfRank - 1] ||
|
||||||
|
paddingTop >= selfShape[selfRank - 2] ||
|
||||||
|
paddingBottom >= selfShape[selfRank - 2])
|
||||||
|
return rewriter.notifyMatchFailure(
|
||||||
|
op, "Padding must be less than the corresponding input dimension");
|
||||||
|
|
||||||
|
// Identity case
|
||||||
|
if (paddingLeft == 0 && paddingRight == 0 && paddingTop == 0 &&
|
||||||
|
paddingBottom == 0) {
|
||||||
|
rewriter.replaceOp(op, self);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Use tosa.slice and tosa.reverse to get the reflection pads based on the
|
||||||
|
// padding size
|
||||||
|
SmallVector<Value> sideTensors;
|
||||||
|
|
||||||
|
if (paddingLeft > 0) {
|
||||||
|
SmallVector<int64_t> leftStartSlice(selfRank, 0);
|
||||||
|
SmallVector<int64_t> leftSizeSlice(selfShape);
|
||||||
|
|
||||||
|
leftStartSlice[selfRank - 1] = 1;
|
||||||
|
leftSizeSlice[selfRank - 1] = paddingLeft;
|
||||||
|
|
||||||
|
SmallVector<int64_t> leftPadShape(selfShape.begin(), selfShape.end() - 1);
|
||||||
|
leftPadShape.push_back(paddingLeft);
|
||||||
|
|
||||||
|
auto leftPadType = RankedTensorType::get(leftPadShape, selfElemTy);
|
||||||
|
|
||||||
|
auto leftPadSlice = rewriter.create<tosa::SliceOp>(
|
||||||
|
op->getLoc(), leftPadType, self,
|
||||||
|
rewriter.getDenseI64ArrayAttr(leftStartSlice),
|
||||||
|
rewriter.getDenseI64ArrayAttr(leftSizeSlice));
|
||||||
|
|
||||||
|
auto leftPad = rewriter.create<tosa::ReverseOp>(
|
||||||
|
op->getLoc(), leftPadType, leftPadSlice.getResult(),
|
||||||
|
static_cast<int32_t>(selfRank - 1));
|
||||||
|
|
||||||
|
sideTensors.push_back(leftPad.getResult());
|
||||||
|
}
|
||||||
|
|
||||||
|
sideTensors.push_back(self);
|
||||||
|
|
||||||
|
if (paddingRight > 0) {
|
||||||
|
SmallVector<int64_t> rightStartSlice(selfRank, 0);
|
||||||
|
SmallVector<int64_t> rightSizeSlice(selfShape);
|
||||||
|
|
||||||
|
rightStartSlice[selfRank - 1] = selfShape[selfRank - 1] - paddingRight - 1;
|
||||||
|
rightSizeSlice[selfRank - 1] = paddingRight;
|
||||||
|
|
||||||
|
SmallVector<int64_t> rightPadShape(selfShape.begin(), selfShape.end() - 1);
|
||||||
|
rightPadShape.push_back(paddingRight);
|
||||||
|
|
||||||
|
auto rightPadType = RankedTensorType::get(rightPadShape, selfElemTy);
|
||||||
|
|
||||||
|
auto rightPadSlice = rewriter.create<tosa::SliceOp>(
|
||||||
|
op->getLoc(), rightPadType, self,
|
||||||
|
rewriter.getDenseI64ArrayAttr(rightStartSlice),
|
||||||
|
rewriter.getDenseI64ArrayAttr(rightSizeSlice));
|
||||||
|
|
||||||
|
auto rightPad = rewriter.create<tosa::ReverseOp>(
|
||||||
|
op->getLoc(), rightPadType, rightPadSlice.getResult(),
|
||||||
|
static_cast<int32_t>(selfRank - 1));
|
||||||
|
|
||||||
|
sideTensors.push_back(rightPad.getResult());
|
||||||
|
}
|
||||||
|
|
||||||
|
SmallVector<int64_t> selfSidePaddedShape(selfShape.begin(),
|
||||||
|
selfShape.end() - 1);
|
||||||
|
selfSidePaddedShape.push_back(resultShape.back());
|
||||||
|
|
||||||
|
auto selfSidePadded = tosa::CreateOpAndInfer<tosa::ConcatOp>(
|
||||||
|
rewriter, op->getLoc(),
|
||||||
|
RankedTensorType::get(selfSidePaddedShape, selfElemTy), sideTensors,
|
||||||
|
selfRank - 1);
|
||||||
|
|
||||||
|
SmallVector<Value> resultTensors;
|
||||||
|
|
||||||
|
if (paddingTop > 0) {
|
||||||
|
SmallVector<int64_t> topStartSlice(selfRank, 0);
|
||||||
|
SmallVector<int64_t> topSizeSlice(selfShape.begin(), selfShape.end() - 1);
|
||||||
|
topSizeSlice.push_back(resultShape.back());
|
||||||
|
|
||||||
|
topStartSlice[selfRank - 2] = 1;
|
||||||
|
topSizeSlice[selfRank - 2] = paddingTop;
|
||||||
|
|
||||||
|
SmallVector<int64_t> topPadShape(selfShape.begin(), selfShape.end() - 2);
|
||||||
|
topPadShape.push_back(paddingTop);
|
||||||
|
topPadShape.push_back(resultShape.back());
|
||||||
|
|
||||||
|
auto topPadType = RankedTensorType::get(topPadShape, selfElemTy);
|
||||||
|
|
||||||
|
auto topPadSlice = rewriter.create<tosa::SliceOp>(
|
||||||
|
op->getLoc(), topPadType, selfSidePadded,
|
||||||
|
rewriter.getDenseI64ArrayAttr(topStartSlice),
|
||||||
|
rewriter.getDenseI64ArrayAttr(topSizeSlice));
|
||||||
|
|
||||||
|
auto topPad = rewriter.create<tosa::ReverseOp>(
|
||||||
|
op->getLoc(), topPadType, topPadSlice.getResult(),
|
||||||
|
static_cast<int32_t>(selfRank - 2));
|
||||||
|
|
||||||
|
resultTensors.push_back(topPad.getResult());
|
||||||
|
}
|
||||||
|
|
||||||
|
resultTensors.push_back(selfSidePadded.getResult());
|
||||||
|
|
||||||
|
if (paddingBottom > 0) {
|
||||||
|
SmallVector<int64_t> bottomStartSlice(selfRank, 0);
|
||||||
|
SmallVector<int64_t> bottomSizeSlice(selfShape.begin(),
|
||||||
|
selfShape.end() - 1);
|
||||||
|
bottomSizeSlice.push_back(resultShape.back());
|
||||||
|
|
||||||
|
bottomStartSlice[selfRank - 2] =
|
||||||
|
selfShape[selfRank - 2] - paddingBottom - 1;
|
||||||
|
bottomSizeSlice[selfRank - 2] = paddingBottom;
|
||||||
|
|
||||||
|
SmallVector<int64_t> bottomPadShape(selfShape.begin(), selfShape.end() - 2);
|
||||||
|
bottomPadShape.push_back(paddingBottom);
|
||||||
|
bottomPadShape.push_back(resultShape.back());
|
||||||
|
|
||||||
|
auto bottomPadType = RankedTensorType::get(bottomPadShape, selfElemTy);
|
||||||
|
|
||||||
|
auto bottomPadSlice = rewriter.create<tosa::SliceOp>(
|
||||||
|
op->getLoc(), bottomPadType, selfSidePadded,
|
||||||
|
rewriter.getDenseI64ArrayAttr(bottomStartSlice),
|
||||||
|
rewriter.getDenseI64ArrayAttr(bottomSizeSlice));
|
||||||
|
|
||||||
|
auto bottomPad = rewriter.create<tosa::ReverseOp>(
|
||||||
|
op->getLoc(), bottomPadType, bottomPadSlice.getResult(),
|
||||||
|
static_cast<int32_t>(selfRank - 2));
|
||||||
|
|
||||||
|
resultTensors.push_back(bottomPad.getResult());
|
||||||
|
}
|
||||||
|
|
||||||
|
auto result = tosa::CreateOpAndInfer<tosa::ConcatOp>(
|
||||||
|
rewriter, op->getLoc(), resultType, resultTensors, selfRank - 2);
|
||||||
|
|
||||||
|
rewriter.replaceOp(op, result);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Legalization for aten.replication_pad2d
|
||||||
|
template <>
|
||||||
|
LogicalResult ConvertAtenOp<AtenReplicationPad2dOp>::matchAndRewrite(
|
||||||
|
AtenReplicationPad2dOp op, OpAdaptor adaptor,
|
||||||
|
ConversionPatternRewriter &rewriter) const {
|
||||||
|
auto self = adaptor.getSelf();
|
||||||
|
|
||||||
|
auto selfType = dyn_cast<TensorType>(self.getType());
|
||||||
|
if (!selfType)
|
||||||
|
return rewriter.notifyMatchFailure(op, "Only tensor types are supported");
|
||||||
|
|
||||||
|
auto selfShape = selfType.getShape();
|
||||||
|
auto selfRank = selfType.getRank();
|
||||||
|
auto selfElemTy = selfType.getElementType();
|
||||||
|
|
||||||
|
auto resultType =
|
||||||
|
dyn_cast<TensorType>(typeConverter->convertType(op.getType()));
|
||||||
|
auto resultShape = resultType.getShape();
|
||||||
|
|
||||||
|
SmallVector<int64_t, 4> paddingList;
|
||||||
|
if (!matchPattern(op.getPadding(), m_TorchListOfConstantInts(paddingList)))
|
||||||
|
return rewriter.notifyMatchFailure(
|
||||||
|
op, "Non-const padding lists are not supported");
|
||||||
|
|
||||||
|
int64_t paddingLeft = paddingList[0];
|
||||||
|
int64_t paddingRight = paddingList[1];
|
||||||
|
int64_t paddingTop = paddingList[2];
|
||||||
|
int64_t paddingBottom = paddingList[3];
|
||||||
|
|
||||||
|
// Identity case
|
||||||
|
if (paddingLeft == 0 && paddingRight == 0 && paddingTop == 0 &&
|
||||||
|
paddingBottom == 0) {
|
||||||
|
rewriter.replaceOp(op, self);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Use tosa.slice to get the reflection pads based on the padding size
|
||||||
|
SmallVector<Value> sideTensors;
|
||||||
|
|
||||||
|
if (paddingLeft > 0) {
|
||||||
|
SmallVector<int64_t> leftStartSlice(selfRank, 0);
|
||||||
|
SmallVector<int64_t> leftSizeSlice(selfShape);
|
||||||
|
|
||||||
|
leftStartSlice[selfRank - 1] = 0;
|
||||||
|
leftSizeSlice[selfRank - 1] = 1;
|
||||||
|
|
||||||
|
SmallVector<int64_t> leftPadSliceShape(selfShape.begin(),
|
||||||
|
selfShape.end() - 1);
|
||||||
|
leftPadSliceShape.push_back(1);
|
||||||
|
|
||||||
|
auto leftPadSliceType =
|
||||||
|
RankedTensorType::get(leftPadSliceShape, selfElemTy);
|
||||||
|
|
||||||
|
auto leftPadSlice = rewriter.create<tosa::SliceOp>(
|
||||||
|
op->getLoc(), leftPadSliceType, self,
|
||||||
|
rewriter.getDenseI64ArrayAttr(leftStartSlice),
|
||||||
|
rewriter.getDenseI64ArrayAttr(leftSizeSlice));
|
||||||
|
|
||||||
|
for (int64_t i = 0; i < paddingLeft; i++)
|
||||||
|
sideTensors.push_back(leftPadSlice.getResult());
|
||||||
|
}
|
||||||
|
|
||||||
|
sideTensors.push_back(self);
|
||||||
|
|
||||||
|
if (paddingRight > 0) {
|
||||||
|
SmallVector<int64_t> rightStartSlice(selfRank, 0);
|
||||||
|
SmallVector<int64_t> rightSizeSlice(selfShape);
|
||||||
|
|
||||||
|
rightStartSlice[selfRank - 1] = selfShape[selfRank - 1] - 1;
|
||||||
|
rightSizeSlice[selfRank - 1] = 1;
|
||||||
|
|
||||||
|
SmallVector<int64_t> rightPadSliceShape(selfShape.begin(),
|
||||||
|
selfShape.end() - 1);
|
||||||
|
rightPadSliceShape.push_back(1);
|
||||||
|
|
||||||
|
auto rightPadSliceType =
|
||||||
|
RankedTensorType::get(rightPadSliceShape, selfElemTy);
|
||||||
|
|
||||||
|
auto rightPadSlice = rewriter.create<tosa::SliceOp>(
|
||||||
|
op->getLoc(), rightPadSliceType, self,
|
||||||
|
rewriter.getDenseI64ArrayAttr(rightStartSlice),
|
||||||
|
rewriter.getDenseI64ArrayAttr(rightSizeSlice));
|
||||||
|
|
||||||
|
for (int64_t i = 0; i < paddingRight; i++)
|
||||||
|
sideTensors.push_back(rightPadSlice.getResult());
|
||||||
|
}
|
||||||
|
|
||||||
|
SmallVector<int64_t> selfSidePaddedShape(selfShape.begin(),
|
||||||
|
selfShape.end() - 1);
|
||||||
|
selfSidePaddedShape.push_back(resultShape.back());
|
||||||
|
|
||||||
|
auto selfSidePadded = tosa::CreateOpAndInfer<tosa::ConcatOp>(
|
||||||
|
rewriter, op->getLoc(),
|
||||||
|
RankedTensorType::get(selfSidePaddedShape, selfElemTy), sideTensors,
|
||||||
|
selfRank - 1);
|
||||||
|
|
||||||
|
SmallVector<Value> resultTensors;
|
||||||
|
|
||||||
|
if (paddingTop > 0) {
|
||||||
|
SmallVector<int64_t> topStartSlice(selfRank, 0);
|
||||||
|
SmallVector<int64_t> topSizeSlice(selfShape.begin(), selfShape.end() - 1);
|
||||||
|
topSizeSlice.push_back(resultShape.back());
|
||||||
|
|
||||||
|
topStartSlice[selfRank - 2] = 0;
|
||||||
|
topSizeSlice[selfRank - 2] = 1;
|
||||||
|
|
||||||
|
SmallVector<int64_t> topPadSliceShape(selfShape.begin(),
|
||||||
|
selfShape.end() - 2);
|
||||||
|
topPadSliceShape.push_back(1);
|
||||||
|
topPadSliceShape.push_back(resultShape.back());
|
||||||
|
|
||||||
|
auto topPadSliceType = RankedTensorType::get(topPadSliceShape, selfElemTy);
|
||||||
|
|
||||||
|
auto topPadSlice = rewriter.create<tosa::SliceOp>(
|
||||||
|
op->getLoc(), topPadSliceType, selfSidePadded,
|
||||||
|
rewriter.getDenseI64ArrayAttr(topStartSlice),
|
||||||
|
rewriter.getDenseI64ArrayAttr(topSizeSlice));
|
||||||
|
|
||||||
|
for (int64_t i = 0; i < paddingTop; i++)
|
||||||
|
resultTensors.push_back(topPadSlice.getResult());
|
||||||
|
}
|
||||||
|
|
||||||
|
resultTensors.push_back(selfSidePadded.getResult());
|
||||||
|
|
||||||
|
if (paddingBottom > 0) {
|
||||||
|
SmallVector<int64_t> bottomStartSlice(selfRank, 0);
|
||||||
|
SmallVector<int64_t> bottomSizeSlice(selfShape.begin(),
|
||||||
|
selfShape.end() - 1);
|
||||||
|
bottomSizeSlice.push_back(resultShape.back());
|
||||||
|
|
||||||
|
bottomStartSlice[selfRank - 2] = selfShape[selfRank - 2] - 1;
|
||||||
|
bottomSizeSlice[selfRank - 2] = 1;
|
||||||
|
|
||||||
|
SmallVector<int64_t> bottomPadSliceShape(selfShape.begin(),
|
||||||
|
selfShape.end() - 2);
|
||||||
|
bottomPadSliceShape.push_back(1);
|
||||||
|
bottomPadSliceShape.push_back(resultShape.back());
|
||||||
|
|
||||||
|
auto bottomPadSliceType =
|
||||||
|
RankedTensorType::get(bottomPadSliceShape, selfElemTy);
|
||||||
|
|
||||||
|
auto bottomPadSlice = rewriter.create<tosa::SliceOp>(
|
||||||
|
op->getLoc(), bottomPadSliceType, selfSidePadded,
|
||||||
|
rewriter.getDenseI64ArrayAttr(bottomStartSlice),
|
||||||
|
rewriter.getDenseI64ArrayAttr(bottomSizeSlice));
|
||||||
|
|
||||||
|
for (int64_t i = 0; i < paddingBottom; i++)
|
||||||
|
resultTensors.push_back(bottomPadSlice.getResult());
|
||||||
|
}
|
||||||
|
|
||||||
|
auto result = tosa::CreateOpAndInfer<tosa::ConcatOp>(
|
||||||
|
rewriter, op->getLoc(), resultType, resultTensors, selfRank - 2);
|
||||||
|
|
||||||
|
rewriter.replaceOp(op, result);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
// -----------------------------------------------------------------------------
|
// -----------------------------------------------------------------------------
|
||||||
|
@ -7521,6 +7947,9 @@ public:
|
||||||
INSERT_ATENOP_PATTERN(AtenAsStridedOp);
|
INSERT_ATENOP_PATTERN(AtenAsStridedOp);
|
||||||
INSERT_ATENOP_PATTERN(AtenClampTensorOp);
|
INSERT_ATENOP_PATTERN(AtenClampTensorOp);
|
||||||
INSERT_ATENOP_PATTERN(PrimsCollapseOp);
|
INSERT_ATENOP_PATTERN(PrimsCollapseOp);
|
||||||
|
INSERT_ATENOP_PATTERN(AtenReflectionPad1dOp);
|
||||||
|
INSERT_ATENOP_PATTERN(AtenReflectionPad2dOp);
|
||||||
|
INSERT_ATENOP_PATTERN(AtenReplicationPad2dOp);
|
||||||
#undef INSERT_ATENOP_PATTERN
|
#undef INSERT_ATENOP_PATTERN
|
||||||
|
|
||||||
#define INSERT_CLONE_ATENOP_PATTERN(AtenOp) \
|
#define INSERT_CLONE_ATENOP_PATTERN(AtenOp) \
|
||||||
|
|
|
@ -1736,6 +1736,20 @@ FX_IMPORTER_TOSA_CRASHING_SET = {
|
||||||
# Write the TOSA set as a "passing" set as it is very early in development
|
# Write the TOSA set as a "passing" set as it is very early in development
|
||||||
# and very few tests work yet.
|
# and very few tests work yet.
|
||||||
TOSA_PASS_SET = {
|
TOSA_PASS_SET = {
|
||||||
|
"ReflectionPad1dModule2dInput_Right",
|
||||||
|
"ReflectionPad1dModule2dInput_basic",
|
||||||
|
"ReflectionPad1dModule3dInput_Left",
|
||||||
|
"ReflectionPad1dModule3dInput_basic",
|
||||||
|
"ReflectionPad2dModule_Bottom",
|
||||||
|
"ReflectionPad2dModule_Left",
|
||||||
|
"ReflectionPad2dModule_Right",
|
||||||
|
"ReflectionPad2dModule_Top",
|
||||||
|
"ReflectionPad2dModule_basic",
|
||||||
|
"ReplicationPad2dModule_basic",
|
||||||
|
"ReplicationPad2dModule_bottom0",
|
||||||
|
"ReplicationPad2dModule_left0",
|
||||||
|
"ReplicationPad2dModule_right0",
|
||||||
|
"ReplicationPad2dModule_top0",
|
||||||
"AdaptiveAvgPool1dNonUnitOutputSizeStaticModule_basic",
|
"AdaptiveAvgPool1dNonUnitOutputSizeStaticModule_basic",
|
||||||
"AdaptiveAvgPool1dUnitOutputSizeStaticModule_basic",
|
"AdaptiveAvgPool1dUnitOutputSizeStaticModule_basic",
|
||||||
"AdaptiveAvgPool2dFixedKernelStrideSizeStaticModule_basic",
|
"AdaptiveAvgPool2dFixedKernelStrideSizeStaticModule_basic",
|
||||||
|
@ -2439,6 +2453,7 @@ MAKE_FX_TOSA_PASS_SET = (
|
||||||
TOSA_PASS_SET
|
TOSA_PASS_SET
|
||||||
| {
|
| {
|
||||||
### Tests additionally passing in make_fx_tosa
|
### Tests additionally passing in make_fx_tosa
|
||||||
|
"AdaptiveAvgPool1dStaticEvenMultiple_basic",
|
||||||
"IsInfiniteModule_basic",
|
"IsInfiniteModule_basic",
|
||||||
"AdaptiveAvgPool2dFixedKernelStrideSizeStaticModule_basic",
|
"AdaptiveAvgPool2dFixedKernelStrideSizeStaticModule_basic",
|
||||||
"AdaptiveAvgPool2dUnitOutputSizeStaticModule_basic",
|
"AdaptiveAvgPool2dUnitOutputSizeStaticModule_basic",
|
||||||
|
@ -4163,7 +4178,6 @@ ONNX_TOSA_XFAIL_SET = {
|
||||||
"ChunkListUnpackDynamic_Module_basic",
|
"ChunkListUnpackDynamic_Module_basic",
|
||||||
"ChunkListUnpackUnevenDynamic_Module_basic",
|
"ChunkListUnpackUnevenDynamic_Module_basic",
|
||||||
"ChunkListUnpackUneven_Module_basic",
|
"ChunkListUnpackUneven_Module_basic",
|
||||||
"ChunkListUnpack_Module_basic",
|
|
||||||
"CollapseAllDimensionsModule_basic",
|
"CollapseAllDimensionsModule_basic",
|
||||||
"CollapseFullDynamicModule_basic",
|
"CollapseFullDynamicModule_basic",
|
||||||
"CollapsePartialDynamicModule_basic",
|
"CollapsePartialDynamicModule_basic",
|
||||||
|
@ -4538,7 +4552,6 @@ ONNX_TOSA_XFAIL_SET = {
|
||||||
"MeanDimNoneDimModule_basic",
|
"MeanDimNoneDimModule_basic",
|
||||||
"MeanDtypeModule_basic",
|
"MeanDtypeModule_basic",
|
||||||
"MeanDynamicSizesModule_basic",
|
"MeanDynamicSizesModule_basic",
|
||||||
"MeanModule_basic",
|
|
||||||
"Mlp1LayerModule_basic",
|
"Mlp1LayerModule_basic",
|
||||||
"Mlp2LayerModuleNoBias_basic",
|
"Mlp2LayerModuleNoBias_basic",
|
||||||
"Mlp2LayerModule_basic",
|
"Mlp2LayerModule_basic",
|
||||||
|
@ -4695,27 +4708,9 @@ ONNX_TOSA_XFAIL_SET = {
|
||||||
"ReduceSumDimIntListDtypeFloatModule_basic",
|
"ReduceSumDimIntListDtypeFloatModule_basic",
|
||||||
"ReduceSumDimIntListDtypeIntModule_basic",
|
"ReduceSumDimIntListDtypeIntModule_basic",
|
||||||
"ReduceSumDimIntListElementTypeBoolModule_basic",
|
"ReduceSumDimIntListElementTypeBoolModule_basic",
|
||||||
"ReduceSumDimIntListEmptyDimModule_basic",
|
|
||||||
"ReduceSumDtypeFloatModule_basic",
|
"ReduceSumDtypeFloatModule_basic",
|
||||||
"ReduceSumDtypeIntModule_basic",
|
"ReduceSumDtypeIntModule_basic",
|
||||||
"ReduceSumElementTypeBoolModule_basic",
|
"ReduceSumElementTypeBoolModule_basic",
|
||||||
"ReduceSumFloatModule_basic",
|
|
||||||
"ReduceSumSignedIntModule_basic",
|
|
||||||
"ReduceSumUnsignedIntModule_basic",
|
|
||||||
"ReflectionPad1dModule2dInput_Right",
|
|
||||||
"ReflectionPad1dModule2dInput_basic",
|
|
||||||
"ReflectionPad1dModule3dInput_Left",
|
|
||||||
"ReflectionPad1dModule3dInput_basic",
|
|
||||||
"ReflectionPad2dModule_Bottom",
|
|
||||||
"ReflectionPad2dModule_Left",
|
|
||||||
"ReflectionPad2dModule_Right",
|
|
||||||
"ReflectionPad2dModule_Top",
|
|
||||||
"ReflectionPad2dModule_basic",
|
|
||||||
"ReplicationPad2dModule_basic",
|
|
||||||
"ReplicationPad2dModule_bottom0",
|
|
||||||
"ReplicationPad2dModule_left0",
|
|
||||||
"ReplicationPad2dModule_right0",
|
|
||||||
"ReplicationPad2dModule_top0",
|
|
||||||
"ResNet18Module_basic",
|
"ResNet18Module_basic",
|
||||||
"ReshapeAliasCollapseModule_basic",
|
"ReshapeAliasCollapseModule_basic",
|
||||||
"ReshapeAliasExpandModule_basic",
|
"ReshapeAliasExpandModule_basic",
|
||||||
|
@ -4878,10 +4873,6 @@ ONNX_TOSA_XFAIL_SET = {
|
||||||
"TypePromotionDifferentCategoryModule_basic",
|
"TypePromotionDifferentCategoryModule_basic",
|
||||||
"TypePromotionSameCategoryDifferentWidthModule_basic",
|
"TypePromotionSameCategoryDifferentWidthModule_basic",
|
||||||
"TypePromotionZeroRankHigherCategoryModule_basic",
|
"TypePromotionZeroRankHigherCategoryModule_basic",
|
||||||
"UnflattenIntNegativeOneDimStaticModule_basic",
|
|
||||||
"UnflattenIntNegativeOneSizeStaticModule_basic",
|
|
||||||
"UnflattenIntStaticModule_basic",
|
|
||||||
"UnflattenStaticModule_basic",
|
|
||||||
"UniformModule_basic",
|
"UniformModule_basic",
|
||||||
"UniformNoCorrelationModule_basic",
|
"UniformNoCorrelationModule_basic",
|
||||||
"UniformStaticShapeModule_basic",
|
"UniformStaticShapeModule_basic",
|
||||||
|
|
|
@ -2439,3 +2439,83 @@ func.func @torch.aten.avg_pool1d.count_include_pad_unsupported_value(%arg0: !tor
|
||||||
%3 = torch.aten.avg_pool1d %arg0, %0, %1, %2, %false, %count_include_pad : !torch.vtensor<[1,512,10],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.bool -> !torch.vtensor<[1,512,10],f32>
|
%3 = torch.aten.avg_pool1d %arg0, %0, %1, %2, %false, %count_include_pad : !torch.vtensor<[1,512,10],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.bool -> !torch.vtensor<[1,512,10],f32>
|
||||||
return %3 : !torch.vtensor<[1,512,10],f32>
|
return %3 : !torch.vtensor<[1,512,10],f32>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: func.func @torch.aten.reflection_pad1d$basic(
|
||||||
|
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[1,2,4],f32>) -> !torch.vtensor<[1,2,8],f32> {
|
||||||
|
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[1,2,4],f32> -> tensor<1x2x4xf32>
|
||||||
|
// CHECK: %[[VAL_2:.*]] = torch.constant.int 3
|
||||||
|
// CHECK: %[[VAL_3:.*]] = torch.constant.int 1
|
||||||
|
// CHECK: %[[VAL_4:.*]] = torch.prim.ListConstruct %[[VAL_2]], %[[VAL_3]] : (!torch.int, !torch.int) -> !torch.list<int>
|
||||||
|
// CHECK: %[[VAL_5:.*]] = tosa.slice %[[VAL_1]] {size = array<i64: 1, 2, 3>, start = array<i64: 0, 0, 1>} : (tensor<1x2x4xf32>) -> tensor<1x2x3xf32>
|
||||||
|
// CHECK: %[[VAL_6:.*]] = tosa.reverse %[[VAL_5]] {axis = 2 : i32} : (tensor<1x2x3xf32>) -> tensor<1x2x3xf32>
|
||||||
|
// CHECK: %[[VAL_7:.*]] = tosa.slice %[[VAL_1]] {size = array<i64: 1, 2, 1>, start = array<i64: 0, 0, 2>} : (tensor<1x2x4xf32>) -> tensor<1x2x1xf32>
|
||||||
|
// CHECK: %[[VAL_8:.*]] = tosa.reverse %[[VAL_7]] {axis = 2 : i32} : (tensor<1x2x1xf32>) -> tensor<1x2x1xf32>
|
||||||
|
// CHECK: %[[VAL_9:.*]] = tosa.concat %[[VAL_6]], %[[VAL_1]], %[[VAL_8]] {axis = 2 : i32} : (tensor<1x2x3xf32>, tensor<1x2x4xf32>, tensor<1x2x1xf32>) -> tensor<1x2x8xf32>
|
||||||
|
// CHECK: %[[VAL_10:.*]] = torch_c.from_builtin_tensor %[[VAL_9]] : tensor<1x2x8xf32> -> !torch.vtensor<[1,2,8],f32>
|
||||||
|
// CHECK: return %[[VAL_10]] : !torch.vtensor<[1,2,8],f32>
|
||||||
|
// CHECK: }
|
||||||
|
func.func @torch.aten.reflection_pad1d$basic(%arg0: !torch.vtensor<[1,2,4],f32>) -> !torch.vtensor<[1,2,8],f32> {
|
||||||
|
%int3 = torch.constant.int 3
|
||||||
|
%int1 = torch.constant.int 1
|
||||||
|
%0 = torch.prim.ListConstruct %int3, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
|
||||||
|
%1 = torch.aten.reflection_pad1d %arg0, %0 : !torch.vtensor<[1,2,4],f32>, !torch.list<int> -> !torch.vtensor<[1,2,8],f32>
|
||||||
|
return %1 : !torch.vtensor<[1,2,8],f32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: func.func @torch.aten.reflection_pad2d$basic(
|
||||||
|
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[1,20,20],f32>) -> !torch.vtensor<[1,40,40],f32> {
|
||||||
|
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[1,20,20],f32> -> tensor<1x20x20xf32>
|
||||||
|
// CHECK: %[[VAL_2:.*]] = torch.constant.int 10
|
||||||
|
// CHECK: %[[VAL_3:.*]] = torch.prim.ListConstruct %[[VAL_2]], %[[VAL_2]], %[[VAL_2]], %[[VAL_2]] : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
|
||||||
|
// CHECK: %[[VAL_4:.*]] = tosa.slice %[[VAL_1]] {size = array<i64: 1, 20, 10>, start = array<i64: 0, 0, 1>} : (tensor<1x20x20xf32>) -> tensor<1x20x10xf32>
|
||||||
|
// CHECK: %[[VAL_5:.*]] = tosa.reverse %[[VAL_4]] {axis = 2 : i32} : (tensor<1x20x10xf32>) -> tensor<1x20x10xf32>
|
||||||
|
// CHECK: %[[VAL_6:.*]] = tosa.slice %[[VAL_1]] {size = array<i64: 1, 20, 10>, start = array<i64: 0, 0, 9>} : (tensor<1x20x20xf32>) -> tensor<1x20x10xf32>
|
||||||
|
// CHECK: %[[VAL_7:.*]] = tosa.reverse %[[VAL_6]] {axis = 2 : i32} : (tensor<1x20x10xf32>) -> tensor<1x20x10xf32>
|
||||||
|
// CHECK: %[[VAL_8:.*]] = tosa.concat %[[VAL_5]], %[[VAL_1]], %[[VAL_7]] {axis = 2 : i32} : (tensor<1x20x10xf32>, tensor<1x20x20xf32>, tensor<1x20x10xf32>) -> tensor<1x20x40xf32>
|
||||||
|
// CHECK: %[[VAL_9:.*]] = tosa.slice %[[VAL_8]] {size = array<i64: 1, 10, 40>, start = array<i64: 0, 1, 0>} : (tensor<1x20x40xf32>) -> tensor<1x10x40xf32>
|
||||||
|
// CHECK: %[[VAL_10:.*]] = tosa.reverse %[[VAL_9]] {axis = 1 : i32} : (tensor<1x10x40xf32>) -> tensor<1x10x40xf32>
|
||||||
|
// CHECK: %[[VAL_11:.*]] = tosa.slice %[[VAL_8]] {size = array<i64: 1, 10, 40>, start = array<i64: 0, 9, 0>} : (tensor<1x20x40xf32>) -> tensor<1x10x40xf32>
|
||||||
|
// CHECK: %[[VAL_12:.*]] = tosa.reverse %[[VAL_11]] {axis = 1 : i32} : (tensor<1x10x40xf32>) -> tensor<1x10x40xf32>
|
||||||
|
// CHECK: %[[VAL_13:.*]] = tosa.concat %[[VAL_10]], %[[VAL_8]], %[[VAL_12]] {axis = 1 : i32} : (tensor<1x10x40xf32>, tensor<1x20x40xf32>, tensor<1x10x40xf32>) -> tensor<1x40x40xf32>
|
||||||
|
// CHECK: %[[VAL_14:.*]] = torch_c.from_builtin_tensor %[[VAL_13]] : tensor<1x40x40xf32> -> !torch.vtensor<[1,40,40],f32>
|
||||||
|
// CHECK: return %[[VAL_14]] : !torch.vtensor<[1,40,40],f32>
|
||||||
|
// CHECK: }
|
||||||
|
func.func @torch.aten.reflection_pad2d$basic(%arg0: !torch.vtensor<[1,20,20],f32>) -> !torch.vtensor<[1,40,40],f32> {
|
||||||
|
%int10 = torch.constant.int 10
|
||||||
|
%0 = torch.prim.ListConstruct %int10, %int10, %int10, %int10 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
|
||||||
|
%1 = torch.aten.reflection_pad2d %arg0, %0 : !torch.vtensor<[1,20,20],f32>, !torch.list<int> -> !torch.vtensor<[1,40,40],f32>
|
||||||
|
return %1 : !torch.vtensor<[1,40,40],f32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: func.func @torch.aten.replication_pad2d$basic(
|
||||||
|
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[1,1,3,3],f32>) -> !torch.vtensor<[1,1,10,6],f32> {
|
||||||
|
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[1,1,3,3],f32> -> tensor<1x1x3x3xf32>
|
||||||
|
// CHECK: %[[VAL_2:.*]] = torch.constant.int 1
|
||||||
|
// CHECK: %[[VAL_3:.*]] = torch.constant.int 2
|
||||||
|
// CHECK: %[[VAL_4:.*]] = torch.constant.int 3
|
||||||
|
// CHECK: %[[VAL_5:.*]] = torch.constant.int 4
|
||||||
|
// CHECK: %[[VAL_6:.*]] = torch.prim.ListConstruct %[[VAL_2]], %[[VAL_3]], %[[VAL_4]], %[[VAL_5]] : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
|
||||||
|
// CHECK: %[[VAL_7:.*]] = tosa.slice %[[VAL_1]] {size = array<i64: 1, 1, 3, 1>, start = array<i64: 0, 0, 0, 0>} : (tensor<1x1x3x3xf32>) -> tensor<1x1x3x1xf32>
|
||||||
|
// CHECK: %[[VAL_8:.*]] = tosa.slice %[[VAL_1]] {size = array<i64: 1, 1, 3, 1>, start = array<i64: 0, 0, 0, 2>} : (tensor<1x1x3x3xf32>) -> tensor<1x1x3x1xf32>
|
||||||
|
// CHECK: %[[VAL_9:.*]] = tosa.concat %[[VAL_7]], %[[VAL_1]], %[[VAL_8]], %[[VAL_8]] {axis = 3 : i32} : (tensor<1x1x3x1xf32>, tensor<1x1x3x3xf32>, tensor<1x1x3x1xf32>, tensor<1x1x3x1xf32>) -> tensor<1x1x3x6xf32>
|
||||||
|
// CHECK: %[[VAL_10:.*]] = tosa.slice %[[VAL_9]] {size = array<i64: 1, 1, 1, 6>, start = array<i64: 0, 0, 0, 0>} : (tensor<1x1x3x6xf32>) -> tensor<1x1x1x6xf32>
|
||||||
|
// CHECK: %[[VAL_11:.*]] = tosa.slice %[[VAL_9]] {size = array<i64: 1, 1, 1, 6>, start = array<i64: 0, 0, 2, 0>} : (tensor<1x1x3x6xf32>) -> tensor<1x1x1x6xf32>
|
||||||
|
// CHECK: %[[VAL_12:.*]] = tosa.concat %[[VAL_10]], %[[VAL_10]], %[[VAL_10]], %[[VAL_9]], %[[VAL_11]], %[[VAL_11]], %[[VAL_11]], %[[VAL_11]] {axis = 2 : i32} : (tensor<1x1x1x6xf32>, tensor<1x1x1x6xf32>, tensor<1x1x1x6xf32>, tensor<1x1x3x6xf32>, tensor<1x1x1x6xf32>, tensor<1x1x1x6xf32>, tensor<1x1x1x6xf32>, tensor<1x1x1x6xf32>) -> tensor<1x1x10x6xf32>
|
||||||
|
// CHECK: %[[VAL_13:.*]] = torch_c.from_builtin_tensor %[[VAL_12]] : tensor<1x1x10x6xf32> -> !torch.vtensor<[1,1,10,6],f32>
|
||||||
|
// CHECK: return %[[VAL_13]] : !torch.vtensor<[1,1,10,6],f32>
|
||||||
|
// CHECK: }
|
||||||
|
func.func @torch.aten.replication_pad2d$basic(%arg0: !torch.vtensor<[1,1,3,3],f32>) -> !torch.vtensor<[1,1,10,6],f32> {
|
||||||
|
%int1 = torch.constant.int 1
|
||||||
|
%int2 = torch.constant.int 2
|
||||||
|
%int3 = torch.constant.int 3
|
||||||
|
%int4 = torch.constant.int 4
|
||||||
|
%0 = torch.prim.ListConstruct %int1, %int2, %int3, %int4 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
|
||||||
|
%1 = torch.aten.replication_pad2d %arg0, %0 : !torch.vtensor<[1,1,3,3],f32>, !torch.list<int> -> !torch.vtensor<[1,1,10,6],f32>
|
||||||
|
return %1 : !torch.vtensor<[1,1,10,6],f32>
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue