Constant pad nd to tosa (#1933)

* Implemented lowering of torch.aten.constant_pad_nd to TOSA (a sketch of the pad-list reordering follows below)
* Added the constant_pad_nd e2e tests to TOSA_PASS_SET
* Added PadModule_basic & PadWithNoneValModule_basic to TOSA_PASS_SET
---------

Co-authored-by: Lisa Liu <lingl@xilinx.com>
lisaliu1 2023-03-15 16:42:15 +01:00 committed by GitHub
parent 2468347376
commit 7d711b9f9f
2 changed files with 73 additions and 1 deletion
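For reference, a minimal standalone sketch (not part of the patch; the helper name toTosaPadTable is illustrative) of the pad-list reordering the new lowering performs: torch.aten.constant_pad_nd supplies (low, high) pairs starting from the last dimension, whereas the tosa.pad padding operand is a [rank, 2] table ordered from dimension 0.

#include <cstdint>
#include <vector>

// Reorder a PyTorch-style pad list (last dimension first, (low, high) pairs)
// into the row-major [rank, 2] table used for the tosa.pad padding operand.
std::vector<int64_t> toTosaPadTable(int64_t rank,
                                    const std::vector<int64_t> &padInts) {
  int64_t padRank = padInts.size() / 2;
  std::vector<int64_t> low(rank, 0), high(rank, 0);
  for (int64_t i = 0; i < padRank; ++i) {
    low[rank - i - 1] = padInts[2 * i];
    high[rank - i - 1] = padInts[2 * i + 1];
  }
  std::vector<int64_t> table;
  for (int64_t i = 0; i < rank; ++i) {
    table.push_back(low[i]);
    table.push_back(high[i]);
  }
  return table;
}

// Example: toTosaPadTable(3, {1, 2, 3, 4}) == {0, 0, 3, 4, 1, 2}.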


@@ -702,7 +702,13 @@ TOSA_PASS_SET = {
"FullLikeModuleInt2DStatic_basic",
"FullModuleInt3D_basic",
"FullModuleFloat2D_basic",
"RepeatModule_basic"
"RepeatModule_basic",
"ConstantPad2dStaticModule_basic",
"ConstantPadNdModule_basic",
"ConstantPadNdPartialStaticModule_basic",
"ConstantPadNdStaticModule_basic",
"PadModule_basic",
"PadWithNoneValModule_basic"
}
LTC_XFAIL_SET = {


@@ -4291,6 +4291,71 @@ public:
}
};
template <>
LogicalResult ConvertAtenOp<AtenConstantPadNdOp>::matchAndRewrite(
    AtenConstantPadNdOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  Location loc = op.getLoc();
  Value self = adaptor.getSelf();
  auto selfTy = self.getType().cast<RankedTensorType>();
  auto selfElemTy = selfTy.getElementType();
  int64_t rank = selfTy.getRank();

  // START the code snippet from
  // lib/Conversion/TorchToLinalg/TensorConstructors.cpp
  // (see: ConvertAtenConstantPadNdOp)
  // Pattern match against the op's original operands, because otherwise we
  // will get the lowered version of the operands which is harder to pattern
  // match.
  SmallVector<int64_t> padInts;
  if (!matchPattern(op.getPad(), m_TorchListOfConstantInts(padInts)))
    return rewriter.notifyMatchFailure(
        op, "only support constant int pad ranges");
  uint64_t padRank = padInts.size() / 2;
  if (padRank * 2 != padInts.size())
    return rewriter.notifyMatchFailure(op, "pad range size is not even");
  if (rank < 0 || padRank > (uint64_t)rank)
    return rewriter.notifyMatchFailure(op, "padding exceeds tensor rank");

  // Initialize low/high paddings with 0 for all the dims.
  SmallVector<int64_t> lowPadding(/*Size=*/rank, /*Value=*/0);
  SmallVector<int64_t> highPadding(/*Size=*/rank, /*Value=*/0);
  // Add the requested padding - note op.pad() is highest dim first ordered
  // pairs of low,high.
  for (uint64_t i = 0; i < padRank; ++i) {
    lowPadding[rank - i - 1] = padInts[i * 2];
    highPadding[rank - i - 1] = padInts[i * 2 + 1];
  }
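  // Example: with a rank-3 input and pad = [1, 2, 3, 4] (the last dim padded
  // by (1, 2), the second-to-last by (3, 4)), this yields
  // lowPadding = [0, 3, 1] and highPadding = [0, 4, 2].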
  // END the code snippet from
  // lib/Conversion/TorchToLinalg/TensorConstructors.cpp
  // (see: ConvertAtenConstantPadNdOp)

  llvm::SmallVector<int64_t> translatePadsList;
  for (unsigned int i = 0; i < rank; i++) {
    translatePadsList.push_back(lowPadding[i]);
    translatePadsList.push_back(highPadding[i]);
  }
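  // Continuing the example above, translatePadsList = [0, 0, 3, 4, 1, 2],
  // i.e. the row-major form of the [rank, 2] table [[0, 0], [3, 4], [1, 2]]
  // expected by tosa.pad.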
  DenseElementsAttr paddingAttr = DenseIntElementsAttr::get(
      RankedTensorType::get({rank, 2}, rewriter.getI64Type()),
      translatePadsList);

  Value padsList1 = rewriter.create<mlir::tosa::ConstOp>(
      loc, paddingAttr.getType(), paddingAttr);

  Value padValue = adaptor.getValue();
  Operation *padOp = padValue.getDefiningOp();
  padValue = padOp->getOperand(0);
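  // adaptor.getValue() is the already-converted pad scalar; stepping back
  // through its defining op recovers the original Torch-level constant so
  // torchScalarToTosaTensor below can pattern-match it.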
  Value padTensor;
  if (failed(torchScalarToTosaTensor(rewriter, op.getOperation(), padValue,
                                     padTensor, selfElemTy, {})))
    return rewriter.notifyMatchFailure(
        op, "Pad value needs to be a scalar constant for conversion to "
            "TOSA pad operation");

  rewriter.replaceOpWithNewOp<mlir::tosa::PadOp>(
      op, getTypeConverter()->convertType(op.getType()), self, padsList1,
      padTensor);
  return success();
}
} // namespace
// -----------------------------------------------------------------------------
@@ -4514,6 +4579,7 @@ public:
INSERT_ATENOP_PATTERN(PrimNumToTensorScalarOp);
INSERT_ATENOP_PATTERN(AtenCopyOp);
INSERT_ATENOP_PATTERN(AtenToDtypeOp);
INSERT_ATENOP_PATTERN(AtenConstantPadNdOp);
#undef INSERT_ATENOP_PATTERN
#define INSERT_CLONE_ATENOP_PATTERN(AtenOp) \