mirror of https://github.com/llvm/torch-mlir
parent
dc470e65c8
commit
e18bf42d0e
|
@ -1632,6 +1632,46 @@ LogicalResult ConvertAtenOp<AtenArangeStartStepOp>::matchAndRewrite(
|
|||
return success();
|
||||
}
|
||||
|
||||
template <>
|
||||
LogicalResult ConvertAtenOp<AtenConstantPadNdOp>::matchAndRewrite(
|
||||
AtenConstantPadNdOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
Value self = adaptor.getSelf();
|
||||
auto selfTy = self.getType().cast<RankedTensorType>();
|
||||
auto selfElemTy = selfTy.getElementType();
|
||||
int64_t rank = selfTy.getRank();
|
||||
|
||||
SmallVector<int64_t> padInts;
|
||||
if (!matchPattern(op.getPad(), m_TorchListOfConstantInts(padInts)))
|
||||
return rewriter.notifyMatchFailure(op,
|
||||
"only support constant int pad ranges");
|
||||
uint64_t padRank = padInts.size() / 2;
|
||||
if (padRank * 2 != padInts.size())
|
||||
return rewriter.notifyMatchFailure(op, "pad range size is not even");
|
||||
if (rank < 0 || padRank > (uint64_t)rank)
|
||||
return rewriter.notifyMatchFailure(op, "padding exceeds tensor rank");
|
||||
|
||||
// Initialize low/high paddings with 0 for all the dims.
|
||||
SmallVector<int64_t> lowPadding(/*Size=*/rank, /*Value=*/0);
|
||||
SmallVector<int64_t> highPadding(/*Size=*/rank, /*Value=*/0);
|
||||
// Add the requested padding - note op.pad() is highest dim first ordered
|
||||
// pairs of low,high.
|
||||
// Add the requested padding - note op.pad() is highest dim first ordered
|
||||
// pairs of low,high.
|
||||
for (uint64_t i = 0; i < padRank; ++i) {
|
||||
lowPadding[rank - i - 1] = padInts[i * 2];
|
||||
highPadding[rank - i - 1] = padInts[i * 2 + 1];
|
||||
}
|
||||
|
||||
Value constantValue = hlo::scalarToStablehloTensor(
|
||||
rewriter, op, adaptor.getValue(), selfElemTy);
|
||||
|
||||
SmallVector<int64_t> interiorPadding(rank, 0);
|
||||
rewriter.replaceOpWithNewOp<stablehlo::PadOp>(
|
||||
op, self, constantValue, lowPadding, highPadding, interiorPadding);
|
||||
return success();
|
||||
}
|
||||
|
||||
template <>
|
||||
LogicalResult ConvertAtenOp<AtenGeluBackwardOp>::matchAndRewrite(
|
||||
AtenGeluBackwardOp op, OpAdaptor adaptor,
|
||||
|
@ -2070,6 +2110,7 @@ void mlir::torch::torch_to_stablehlo::populateBasicOpPatternsAndLegality(
|
|||
INSERT_ATENOP_PATTERN(PrimNumToTensorScalarOp);
|
||||
INSERT_ATENOP_PATTERN(AtenScalarImplicitOp);
|
||||
INSERT_ATENOP_PATTERN(AtenContiguousOp);
|
||||
INSERT_ATENOP_PATTERN(AtenConstantPadNdOp);
|
||||
|
||||
INSERT_ATENOP_PATTERN(AtenReluOp);
|
||||
INSERT_ATENOP_PATTERN(AtenGeluOp);
|
||||
|
|
|
@ -605,10 +605,6 @@ FX_IMPORTER_STABLEHLO_XFAIL_SET = {
|
|||
"BroadcastDynamicDimModule_basic",
|
||||
"CeilFloatModule_basic",
|
||||
"ConstantBoolParameterModule_basic",
|
||||
"ConstantPad2dStaticModule_basic",
|
||||
"ConstantPadNdModule_basic",
|
||||
"ConstantPadNdPartialStaticModule_basic",
|
||||
"ConstantPadNdStaticModule_basic",
|
||||
"ContainsIntList_False",
|
||||
"ContainsIntList_True",
|
||||
"Conv2dQInt8Module_basic",
|
||||
|
@ -754,8 +750,6 @@ FX_IMPORTER_STABLEHLO_XFAIL_SET = {
|
|||
"NumToTensorIntModule_basic",
|
||||
"NumelModule_basic",
|
||||
"NumelZeroRankModule_basic",
|
||||
"PadModule_basic",
|
||||
"PadWithNoneValModule_basic",
|
||||
"PixelShuffleModuleFullDynamic_basic",
|
||||
"PixelShuffleModuleSpatiallyDynamic_basic",
|
||||
"PixelShuffleModuleSpatiallyStatic_basic",
|
||||
|
@ -982,6 +976,10 @@ STABLEHLO_PASS_SET = {
|
|||
"Convolution2DStaticModule_basic",
|
||||
"ConvolutionBackwardModule2DStatic_basic",
|
||||
"ConvolutionModule2DTransposeStridedStatic_basic",
|
||||
"ConstantPad2dStaticModule_basic",
|
||||
"ConstantPadNdModule_basic",
|
||||
"ConstantPadNdPartialStaticModule_basic",
|
||||
"ConstantPadNdStaticModule_basic",
|
||||
"CosineSimilarityStaticBroadcastModule_basic",
|
||||
"CosineSimilarityStaticModule_basic",
|
||||
"CumsumInputDtypeInt32Module_basic",
|
||||
|
@ -1209,6 +1207,8 @@ STABLEHLO_PASS_SET = {
|
|||
"OnesModuleFalsePinMemory_basic",
|
||||
"OnesModuleFloat_basic",
|
||||
"OnesModuleInt_basic",
|
||||
"PadModule_basic",
|
||||
"PadWithNoneValModule_basic",
|
||||
"Permute0RankModule_basic",
|
||||
"PermuteModule_basic",
|
||||
"PermuteNegativeIndexModule_basic",
|
||||
|
|
Loading…
Reference in New Issue