mirror of https://github.com/llvm/torch-mlir
[ONNX] Fix bug in ONNXToTorch PadOp's pads tensor rearrangement (#3485)
Fix the pad tensor rearrangement such that we change the representation from [x1_begin, x2_begin, ..., x1_end, x2_end,...] to [xn_begin, xn_end, ...., x2_begin, x2_end, x1_begin, x1_end] where x1, x2 .. xn are the dimensions of the pads tensor argument. --------- Co-authored-by: zjgarvey <zjgarvey@gmail.com> Co-authored-by: zjgarvey <47986913+zjgarvey@users.noreply.github.com>pull/3521/head
parent
ca0e906675
commit
0fe74845da
|
@ -2315,12 +2315,15 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
|
|||
}
|
||||
|
||||
// The torch.pad op expects a different arrangement of padding pairs for
|
||||
// each dimension as compared to the onnx.pad op. So, rearranging pad
|
||||
// tensor to satisfy torch.pad op semantics.
|
||||
// each dimension as compared to the onnx.pad op. Rearrange the pad
|
||||
// tensor as shown below:
|
||||
//
|
||||
// [x1_begin, x2_begin, ..., x1_end, x2_end,...] ->
|
||||
// [xn_begin, xn_end, ...., x2_begin, x2_end, x1_begin, x1_end]
|
||||
SmallVector<Value> padsRearrange;
|
||||
for (uint32_t i = 0; i < padsSize / 2; i++) {
|
||||
for (uint32_t i = padsSize - 1; i >= padsSize / 2; i--) {
|
||||
padsRearrange.emplace_back(padsTensorValue[i - padsSize / 2]);
|
||||
padsRearrange.emplace_back(padsTensorValue[i]);
|
||||
padsRearrange.emplace_back(padsTensorValue[(padsSize / 2) + i]);
|
||||
}
|
||||
|
||||
Value padsSizeList =
|
||||
|
|
|
@ -3664,7 +3664,9 @@ OpFoldResult AtenSliceTensorOp::fold(FoldAdaptor adaptor) {
|
|||
return DenseElementsAttr::get(outType.toBuiltinTensor(), values);
|
||||
}
|
||||
|
||||
// If the input and output shapes are the same we can just fold:
|
||||
// If the input and output shapes are the same & step == 1 we can fold:
|
||||
if (!step || step.getValue().getSExtValue() != 1)
|
||||
return nullptr;
|
||||
for (size_t i = 0; i < inType.getSizes().size(); ++i) {
|
||||
if (inType.getSizes()[i] != outType.getSizes()[i])
|
||||
return nullptr;
|
||||
|
|
|
@ -2216,8 +2216,6 @@ ONNX_XFAIL_SET = {
|
|||
"ElementwiseLog2IntModule_basic",
|
||||
"ElementwiseFminModule_basic",
|
||||
"ElementwiseFmaxModule_basic",
|
||||
"FlipModuleStaticShape_basic",
|
||||
"FlipNegativeIndexModule_basic",
|
||||
"PixelShuffleModuleStaticRank4Float32_basic",
|
||||
"ReflectionPad1dModule2dInput_Right",
|
||||
"ReflectionPad1dModule2dInput_basic",
|
||||
|
|
|
@ -854,7 +854,7 @@ func.func @test_pad(%arg0: !torch.vtensor<[3,4],f32>, %arg1: !torch.vtensor<[4],
|
|||
// CHECK: %[[INT3:.+]] = torch.constant.int 3
|
||||
// CHECK: %[[SELECT_3:.+]] = torch.aten.select.int %arg1, %[[INT0]], %[[INT3]] : !torch.vtensor<[4],si64>, !torch.int, !torch.int -> !torch.vtensor<[],si64>
|
||||
// CHECK: %[[ITEM_3:.+]] = torch.aten.item %[[SELECT_3]] : !torch.vtensor<[],si64> -> !torch.int
|
||||
// CHECK: %[[LIST:.+]] = torch.prim.ListConstruct %[[ITEM_0]], %[[ITEM_2]], %[[ITEM_1]], %[[ITEM_3]] : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
|
||||
// CHECK: %[[LIST:.+]] = torch.prim.ListConstruct %[[ITEM_1]], %[[ITEM_3]], %[[ITEM_0]], %[[ITEM_2]] : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
|
||||
// CHECK: %[[STR:.+]] = torch.constant.str "constant"
|
||||
// CHECK: %[[PAD:.+]] = torch.aten.pad %arg0, %[[LIST]], %[[STR]], %[[VAL]] : !torch.vtensor<[3,4],f32>, !torch.list<int>, !torch.str, !torch.float -> !torch.vtensor<[5,4],f32>
|
||||
// CHECK: return %[[PAD]] : !torch.vtensor<[5,4],f32>
|
||||
|
|
Loading…
Reference in New Issue