[stablehlo] Support ConstantPadNdOp in stablehlo (#3211)

as title
pull/3222/head
Xinyu Yang 2024-04-24 14:15:11 +08:00 committed by GitHub
parent dc470e65c8
commit e18bf42d0e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 47 additions and 6 deletions

View File

@ -1632,6 +1632,46 @@ LogicalResult ConvertAtenOp<AtenArangeStartStepOp>::matchAndRewrite(
return success();
}
// Lower aten.constant_pad_nd to stablehlo.pad.
//
// aten.constant_pad_nd receives a flat pad list ordered from the *last*
// dimension outward, as (low, high) pairs; stablehlo.pad wants per-dimension
// low/high/interior arrays in natural dimension order, so the pad list is
// reversed into those arrays below.
template <>
LogicalResult ConvertAtenOp<AtenConstantPadNdOp>::matchAndRewrite(
    AtenConstantPadNdOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  Value self = adaptor.getSelf();
  // Bail out gracefully (instead of asserting via cast<>) when the input is
  // not a ranked tensor; the lowering needs a static rank to build the
  // per-dimension padding arrays.
  auto selfTy = self.getType().dyn_cast<RankedTensorType>();
  if (!selfTy)
    return rewriter.notifyMatchFailure(op,
                                       "only ranked tensor types are supported");
  auto selfElemTy = selfTy.getElementType();
  int64_t rank = selfTy.getRank();

  // The pad amounts must be compile-time constants to emit a static
  // stablehlo.pad op.
  SmallVector<int64_t> padInts;
  if (!matchPattern(op.getPad(), m_TorchListOfConstantInts(padInts)))
    return rewriter.notifyMatchFailure(op,
                                       "only support constant int pad ranges");
  uint64_t padRank = padInts.size() / 2;
  if (padRank * 2 != padInts.size())
    return rewriter.notifyMatchFailure(op, "pad range size is not even");
  if (rank < 0 || padRank > (uint64_t)rank)
    return rewriter.notifyMatchFailure(op, "padding exceeds tensor rank");

  // Initialize low/high paddings with 0 for all the dims.
  SmallVector<int64_t> lowPadding(/*Size=*/rank, /*Value=*/0);
  SmallVector<int64_t> highPadding(/*Size=*/rank, /*Value=*/0);
  // Add the requested padding - note op.pad() is highest dim first ordered
  // pairs of low,high.
  for (uint64_t i = 0; i < padRank; ++i) {
    lowPadding[rank - i - 1] = padInts[i * 2];
    highPadding[rank - i - 1] = padInts[i * 2 + 1];
  }

  // stablehlo.pad takes the fill value as a 0-d tensor of the input's
  // element type.
  Value constantValue = hlo::scalarToStablehloTensor(rewriter, op,
                                                     adaptor.getValue(),
                                                     selfElemTy);

  // No dilation between elements: interior padding is all zeros.
  SmallVector<int64_t> interiorPadding(rank, 0);
  rewriter.replaceOpWithNewOp<stablehlo::PadOp>(
      op, self, constantValue, lowPadding, highPadding, interiorPadding);
  return success();
}
template <>
LogicalResult ConvertAtenOp<AtenGeluBackwardOp>::matchAndRewrite(
AtenGeluBackwardOp op, OpAdaptor adaptor,
@ -2070,6 +2110,7 @@ void mlir::torch::torch_to_stablehlo::populateBasicOpPatternsAndLegality(
INSERT_ATENOP_PATTERN(PrimNumToTensorScalarOp);
INSERT_ATENOP_PATTERN(AtenScalarImplicitOp);
INSERT_ATENOP_PATTERN(AtenContiguousOp);
INSERT_ATENOP_PATTERN(AtenConstantPadNdOp);
INSERT_ATENOP_PATTERN(AtenReluOp);
INSERT_ATENOP_PATTERN(AtenGeluOp);

View File

@ -605,10 +605,6 @@ FX_IMPORTER_STABLEHLO_XFAIL_SET = {
"BroadcastDynamicDimModule_basic",
"CeilFloatModule_basic",
"ConstantBoolParameterModule_basic",
"ConstantPad2dStaticModule_basic",
"ConstantPadNdModule_basic",
"ConstantPadNdPartialStaticModule_basic",
"ConstantPadNdStaticModule_basic",
"ContainsIntList_False",
"ContainsIntList_True",
"Conv2dQInt8Module_basic",
@ -754,8 +750,6 @@ FX_IMPORTER_STABLEHLO_XFAIL_SET = {
"NumToTensorIntModule_basic",
"NumelModule_basic",
"NumelZeroRankModule_basic",
"PadModule_basic",
"PadWithNoneValModule_basic",
"PixelShuffleModuleFullDynamic_basic",
"PixelShuffleModuleSpatiallyDynamic_basic",
"PixelShuffleModuleSpatiallyStatic_basic",
@ -982,6 +976,10 @@ STABLEHLO_PASS_SET = {
"Convolution2DStaticModule_basic",
"ConvolutionBackwardModule2DStatic_basic",
"ConvolutionModule2DTransposeStridedStatic_basic",
"ConstantPad2dStaticModule_basic",
"ConstantPadNdModule_basic",
"ConstantPadNdPartialStaticModule_basic",
"ConstantPadNdStaticModule_basic",
"CosineSimilarityStaticBroadcastModule_basic",
"CosineSimilarityStaticModule_basic",
"CumsumInputDtypeInt32Module_basic",
@ -1209,6 +1207,8 @@ STABLEHLO_PASS_SET = {
"OnesModuleFalsePinMemory_basic",
"OnesModuleFloat_basic",
"OnesModuleInt_basic",
"PadModule_basic",
"PadWithNoneValModule_basic",
"Permute0RankModule_basic",
"PermuteModule_basic",
"PermuteNegativeIndexModule_basic",