[ONNX] Fix bug in ONNXToTorch PadOp's pads tensor rearrangement (#3485)

Fix the pads tensor rearrangement so that the representation changes from
[x1_begin, x2_begin, ..., x1_end, x2_end, ...] to [xn_begin, xn_end,
..., x2_begin, x2_end, x1_begin, x1_end], where x1, x2, ..., xn are the
input dimensions that the entries of the pads argument describe.
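
For intuition, a minimal standalone sketch of the new ordering (plain
integers stand in for the lowering's Value handles; variable names are
illustrative, not from the patch):

    #include <cstdint>
    #include <iostream>
    #include <vector>

    int main() {
      // ONNX pads for a rank-3 input:
      // [x1_begin, x2_begin, x3_begin, x1_end, x2_end, x3_end]
      std::vector<int64_t> onnxPads = {1, 2, 3, 10, 20, 30};
      uint32_t padsSize = onnxPads.size(); // assumed even and >= 2

      // torch.pad wants the last dimension's pair first:
      // [x3_begin, x3_end, x2_begin, x2_end, x1_begin, x1_end]
      std::vector<int64_t> torchPads;
      for (uint32_t i = padsSize - 1; i >= padsSize / 2; i--) {
        torchPads.push_back(onnxPads[i - padsSize / 2]); // xi_begin
        torchPads.push_back(onnxPads[i]);                // xi_end
      }
      for (int64_t v : torchPads)
        std::cout << v << ' '; // prints: 3 30 2 20 1 10
      std::cout << '\n';
      return 0;
    }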

---------

Co-authored-by: zjgarvey <zjgarvey@gmail.com>
Co-authored-by: zjgarvey <47986913+zjgarvey@users.noreply.github.com>
Sagar Kulkarni, 2024-07-03 16:02:49 -04:00, committed by GitHub
parent ca0e906675, commit 0fe74845da
4 changed files with 11 additions and 8 deletions


@@ -2315,12 +2315,15 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
   }
   // The torch.pad op expects a different arrangement of padding pairs for
-  // each dimension as compared to the onnx.pad op. So, rearranging pad
-  // tensor to satisfy torch.pad op semantics.
+  // each dimension as compared to the onnx.pad op. Rearrange the pad
+  // tensor as shown below:
+  //
+  // [x1_begin, x2_begin, ..., x1_end, x2_end,...] ->
+  // [xn_begin, xn_end, ...., x2_begin, x2_end, x1_begin, x1_end]
   SmallVector<Value> padsRearrange;
-  for (uint32_t i = 0; i < padsSize / 2; i++) {
+  for (uint32_t i = padsSize - 1; i >= padsSize / 2; i--) {
+    padsRearrange.emplace_back(padsTensorValue[i - padsSize / 2]);
     padsRearrange.emplace_back(padsTensorValue[i]);
-    padsRearrange.emplace_back(padsTensorValue[(padsSize / 2) + i]);
   }
   Value padsSizeList =
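
As a sanity check, a hypothetical standalone trace of what the old and
the fixed loops emit for a rank-2 input (plain ints in place of the
Value handles in padsTensorValue):

    #include <cstdint>
    #include <vector>

    int main() {
      // Indices 0..3 model [x1_begin, x2_begin, x1_end, x2_end].
      std::vector<int> padsTensorValue = {0, 1, 2, 3};
      uint32_t padsSize = padsTensorValue.size(); // 4

      std::vector<int> oldOrder, newOrder;
      // Old loop: walks dimensions front to back, so x1's pair lands
      // first -- the wrong end for torch.pad.
      for (uint32_t i = 0; i < padsSize / 2; i++) {
        oldOrder.push_back(padsTensorValue[i]);
        oldOrder.push_back(padsTensorValue[(padsSize / 2) + i]);
      }
      // Fixed loop: walks the "end" half backwards, pairing each end
      // value with its matching begin value.
      for (uint32_t i = padsSize - 1; i >= padsSize / 2; i--) {
        newOrder.push_back(padsTensorValue[i - padsSize / 2]);
        newOrder.push_back(padsTensorValue[i]);
      }
      // oldOrder == {0, 2, 1, 3} -> [x1_begin, x1_end, x2_begin, x2_end]
      // newOrder == {1, 3, 0, 2} -> [x2_begin, x2_end, x1_begin, x1_end]
      return 0;
    }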


@@ -3664,7 +3664,9 @@ OpFoldResult AtenSliceTensorOp::fold(FoldAdaptor adaptor) {
     return DenseElementsAttr::get(outType.toBuiltinTensor(), values);
   }
-  // If the input and output shapes are the same we can just fold:
+  // If the input and output shapes are the same & step == 1 we can fold:
+  if (!step || step.getValue().getSExtValue() != 1)
+    return nullptr;
   for (size_t i = 0; i < inType.getSizes().size(); ++i) {
     if (inType.getSizes()[i] != outType.getSizes()[i])
       return nullptr;
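
Restated as a standalone predicate (a sketch of the intended guard, not
the actual TorchOps.cpp code; names and types are simplified
assumptions):

    #include <cstdint>
    #include <optional>
    #include <vector>

    // A slice folds to its input only when it is provably the identity:
    // every dimension keeps its size and elements are taken
    // consecutively. Without the step == 1 check, a strided slice whose
    // output happens to match the input shape could be folded
    // incorrectly.
    bool isIdentitySlice(const std::vector<int64_t> &inSizes,
                         const std::vector<int64_t> &outSizes,
                         std::optional<int64_t> step) {
      if (!step || *step != 1)
        return false; // unknown or non-unit step: never fold
      return inSizes == outSizes;
    }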


@@ -2216,8 +2216,6 @@ ONNX_XFAIL_SET = {
     "ElementwiseLog2IntModule_basic",
     "ElementwiseFminModule_basic",
     "ElementwiseFmaxModule_basic",
-    "FlipModuleStaticShape_basic",
-    "FlipNegativeIndexModule_basic",
     "PixelShuffleModuleStaticRank4Float32_basic",
     "ReflectionPad1dModule2dInput_Right",
     "ReflectionPad1dModule2dInput_basic",


@@ -854,7 +854,7 @@ func.func @test_pad(%arg0: !torch.vtensor<[3,4],f32>, %arg1: !torch.vtensor<[4],
 // CHECK: %[[INT3:.+]] = torch.constant.int 3
 // CHECK: %[[SELECT_3:.+]] = torch.aten.select.int %arg1, %[[INT0]], %[[INT3]] : !torch.vtensor<[4],si64>, !torch.int, !torch.int -> !torch.vtensor<[],si64>
 // CHECK: %[[ITEM_3:.+]] = torch.aten.item %[[SELECT_3]] : !torch.vtensor<[],si64> -> !torch.int
-// CHECK: %[[LIST:.+]] = torch.prim.ListConstruct %[[ITEM_0]], %[[ITEM_2]], %[[ITEM_1]], %[[ITEM_3]] : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
+// CHECK: %[[LIST:.+]] = torch.prim.ListConstruct %[[ITEM_1]], %[[ITEM_3]], %[[ITEM_0]], %[[ITEM_2]] : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
 // CHECK: %[[STR:.+]] = torch.constant.str "constant"
 // CHECK: %[[PAD:.+]] = torch.aten.pad %arg0, %[[LIST]], %[[STR]], %[[VAL]] : !torch.vtensor<[3,4],f32>, !torch.list<int>, !torch.str, !torch.float -> !torch.vtensor<[5,4],f32>
 // CHECK: return %[[PAD]] : !torch.vtensor<[5,4],f32>
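
A worked reading of this 2-D test: the select/item chain makes ITEM_0
through ITEM_3 the raw pads entries [x1_begin, x2_begin, x1_end,
x2_end], so the corrected ListConstruct order [ITEM_1, ITEM_3, ITEM_0,
ITEM_2] spells [x2_begin, x2_end, x1_begin, x1_end], the
last-dimension-first pairing torch.aten.pad expects, matching the loop
trace shown earlier.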