mirror of https://github.com/llvm/torch-mlir
Constant pad nd to tosa (#1933)
* implemented lowering torch.aten.constant_pad_nd to tosa * add constant_pad_nd e2e tests to TOSA_PASS_SET * add PadModule_basic & PadWithNoneValModule_basic to TOSA_PASS_SET --------- Co-authored-by: Lisa Liu <lingl@xilinx.com>pull/1944/head snapshot-20230316.779
parent
2468347376
commit
7d711b9f9f
|
@ -702,7 +702,13 @@ TOSA_PASS_SET = {
|
|||
"FullLikeModuleInt2DStatic_basic",
|
||||
"FullModuleInt3D_basic",
|
||||
"FullModuleFloat2D_basic",
|
||||
"RepeatModule_basic"
|
||||
"RepeatModule_basic",
|
||||
"ConstantPad2dStaticModule_basic",
|
||||
"ConstantPadNdModule_basic",
|
||||
"ConstantPadNdPartialStaticModule_basic",
|
||||
"ConstantPadNdStaticModule_basic",
|
||||
"PadModule_basic",
|
||||
"PadWithNoneValModule_basic"
|
||||
}
|
||||
|
||||
LTC_XFAIL_SET = {
|
||||
|
|
|
@ -4291,6 +4291,71 @@ public:
|
|||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
LogicalResult ConvertAtenOp<AtenConstantPadNdOp>::matchAndRewrite(
|
||||
AtenConstantPadNdOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
Location loc = op.getLoc();
|
||||
Value self = adaptor.getSelf();
|
||||
auto selfTy = self.getType().cast<RankedTensorType>();
|
||||
auto selfElemTy = selfTy.getElementType();
|
||||
int64_t rank = selfTy.getRank();
|
||||
|
||||
// START the code snippet from lib/Conversion/TorchToLinalg/TensorConstructors.cpp (see: ConvertAtenConstantPadNdOp)
|
||||
// Pattern match against the op's original operands, because otherwise we
|
||||
// will get the lowered version of the operands which is harder to pattern
|
||||
// match.
|
||||
SmallVector<int64_t> padInts;
|
||||
if (!matchPattern(op.getPad(), m_TorchListOfConstantInts(padInts)))
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "only support constant int pad ranges");
|
||||
uint64_t padRank = padInts.size() / 2;
|
||||
if (padRank * 2 != padInts.size())
|
||||
return rewriter.notifyMatchFailure(op, "pad range size is not even");
|
||||
if (rank < 0 || padRank > (uint64_t)rank)
|
||||
return rewriter.notifyMatchFailure(op, "padding exceeds tensor rank");
|
||||
|
||||
// Initialize low/high paddings with 0 for all the dims.
|
||||
SmallVector<int64_t> lowPadding(/*Size=*/rank, /*Value=*/0);
|
||||
SmallVector<int64_t> highPadding(/*Size=*/rank, /*Value=*/0);
|
||||
// Add the requested padding - note op.pad() is highest dim first ordered
|
||||
// pairs of low,high.
|
||||
for (uint64_t i = 0; i < padRank; ++i) {
|
||||
lowPadding[rank-i-1] = padInts[i * 2];
|
||||
highPadding[rank-i-1] = padInts[i * 2 + 1];
|
||||
}
|
||||
//END the code snippet from lib/Conversion/TorchToLinalg/TensorConstructors.cpp (see: ConvertAtenConstantPadNdOp)
|
||||
|
||||
llvm::SmallVector<int64_t> translatePadsList;
|
||||
|
||||
for (unsigned int i = 0; i < rank; i++) {
|
||||
translatePadsList.push_back(lowPadding[i]);
|
||||
translatePadsList.push_back(highPadding[i]);
|
||||
}
|
||||
|
||||
DenseElementsAttr paddingAttr = DenseIntElementsAttr::get(
|
||||
RankedTensorType::get({rank, 2}, rewriter.getI64Type()),
|
||||
translatePadsList);
|
||||
|
||||
Value padsList1 = rewriter.create<mlir::tosa::ConstOp>(
|
||||
loc, paddingAttr.getType(), paddingAttr);
|
||||
|
||||
Value padValue = adaptor.getValue();
|
||||
Operation *padOp = padValue.getDefiningOp();
|
||||
padValue = padOp->getOperand(0);
|
||||
|
||||
Value padTensor;
|
||||
if (failed(torchScalarToTosaTensor(rewriter, op.getOperation(), padValue,
|
||||
padTensor, selfElemTy, {})))
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "Pad value needs to be a scalar constant for conversion to "
|
||||
"TOSA pad operation");
|
||||
|
||||
rewriter.replaceOpWithNewOp<mlir::tosa::PadOp>(
|
||||
op, getTypeConverter()->convertType(op.getType()), self, padsList1, padTensor);
|
||||
return success();
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
// -----------------------------------------------------------------------------
|
||||
|
@ -4514,6 +4579,7 @@ public:
|
|||
INSERT_ATENOP_PATTERN(PrimNumToTensorScalarOp);
|
||||
INSERT_ATENOP_PATTERN(AtenCopyOp);
|
||||
INSERT_ATENOP_PATTERN(AtenToDtypeOp);
|
||||
INSERT_ATENOP_PATTERN(AtenConstantPadNdOp);
|
||||
#undef INSERT_ATENOP_PATTERN
|
||||
|
||||
#define INSERT_CLONE_ATENOP_PATTERN(AtenOp) \
|
||||
|
|
Loading…
Reference in New Issue