diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp
index 0ca182d3c..ef50c3bca 100644
--- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp
+++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp
@@ -2700,15 +2700,13 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
         Value data, pads, axes;
         std::string mode;
 
-        // TODO: The `axes` parameter is not supported yet.
-        if (!binder.tensorOperandAtIndex(axes, 3)) {
-          return rewriter.notifyMatchFailure(
-              binder.op, "The axes parameter is not supported yet");
-        }
         if (binder.tensorOperandAtIndex(data, 0) ||
             binder.tensorResultType(resultType) ||
             binder.customOpNameStringAttr(mode, "mode", "constant"))
           return failure();
+
+        (void)binder.tensorOperandAtIndex(axes, 3);
+
         bool cstMode = (mode == "constant");
 
         // get input rank
@@ -2822,6 +2820,90 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
         if (!cstMode)
           constantValue = rewriter.create<Torch::ConstantNoneOp>(loc);
 
+        llvm::SmallVector<Value> begins;
+        llvm::SmallVector<Value> ends;
+        for (uint32_t i = 0; i < padsSize / 2; ++i)
+          begins.push_back(padsTensorValue[i]);
+        for (uint32_t i = padsSize / 2; i < padsSize; ++i)
+          ends.push_back(padsTensorValue[i]);
+
+        // If we have the axes we need to compute the appropriate pads:
+        if (axes) {
+          auto axesTy = cast<Torch::ValueTensorType>(axes.getType());
+          assert(axesTy.getSizes().size() == 1);
+          assert(axesTy.getSizes()[0] != Torch::kUnknownSize);
+
+          auto dataTensorType = cast<Torch::ValueTensorType>(data.getType());
+          int64_t rank = dataTensorType.getSizes().size();
+          auto boolTy = rewriter.getType<Torch::BoolType>();
+          auto intTy = rewriter.getType<Torch::IntType>();
+          Value constZero = rewriter.create<Torch::ConstantIntOp>(
+              loc, rewriter.getI64IntegerAttr(0));
+
+          // Extract the values:
+          int64_t numAxes = axesTy.getSizes()[0];
+          Type axesElemType = Torch::ValueTensorType::get(
+              axesTy.getContext(), ArrayRef<int64_t>{},
+              axesTy.getOptionalDtype());
+          llvm::SmallVector<Value> axesExtracted;
+          Value rankV = rewriter.create<Torch::ConstantIntOp>(
+              loc, rewriter.getI64IntegerAttr(rank));
+          for (uint32_t i = 0; i < numAxes; ++i) {
+            Value index = rewriter.create<Torch::ConstantIntOp>(
+                loc, rewriter.getI64IntegerAttr(i));
+            auto select = rewriter.create<Torch::AtenSelectIntOp>(
+                loc, axesElemType, axes, constZero, index);
+            Value selectInt = rewriter.create<Torch::AtenItemOp>(
+                loc, rewriter.getType<Torch::IntType>(), select);
+
+            Value negAxis = rewriter.create<Torch::AtenLtIntOp>(
+                loc, boolTy, selectInt, constZero);
+            negAxis =
+                rewriter.create<Torch::AtenIntBoolOp>(loc, intTy, negAxis);
+            Value axis = rewriter.create<Torch::AtenMulIntOp>(loc, intTy,
+                                                              negAxis, rankV);
+            axis = rewriter.create<Torch::AtenAddIntOp>(loc, intTy, axis,
+                                                        selectInt);
+            axesExtracted.push_back(axis);
+          }
+
+          llvm::SmallVector<Value> newBegins;
+          llvm::SmallVector<Value> newEnds;
+
+          for (int j = 0; j < rank; ++j) {
+            Value newBegin = constZero;
+            Value newEnd = constZero;
+            Value iv = rewriter.create<Torch::ConstantIntOp>(
+                loc, rewriter.getI64IntegerAttr(j));
+
+            for (size_t i = 0; i < axesExtracted.size(); ++i) {
+              Value begin = begins[i];
+              Value end = ends[i];
+
+              Value sameAxis = rewriter.create<Torch::AtenEqIntOp>(
+                  loc, boolTy, axesExtracted[i], iv);
+              sameAxis =
+                  rewriter.create<Torch::AtenIntBoolOp>(loc, intTy, sameAxis);
+
+              begin = rewriter.create<Torch::AtenMulIntOp>(loc, intTy,
+                                                           sameAxis, begin);
+              end = rewriter.create<Torch::AtenMulIntOp>(loc, intTy, sameAxis,
+                                                         end);
+
+              newBegin = rewriter.create<Torch::AtenAddIntOp>(loc, intTy,
+                                                              newBegin, begin);
+              newEnd =
+                  rewriter.create<Torch::AtenAddIntOp>(loc, intTy, newEnd, end);
+            }
+
+            newBegins.push_back(newBegin);
+            newEnds.push_back(newEnd);
+          }
+
+          begins = std::move(newBegins);
+          ends = std::move(newEnds);
+        }
+
         // The torch.pad op expects a different arrangement of padding pairs for
         // each dimension as compared to the onnx.pad op. Rearrange the pad
         // tensor as shown below:
@@ -2829,9 +2911,9 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
         // [x1_begin, x2_begin, ..., x1_end, x2_end,...] ->
         // [xn_begin, xn_end, ...., x2_begin, x2_end, x1_begin, x1_end]
         SmallVector<Value> padsRearrange;
-        for (uint32_t i = padsSize - 1; i >= padsSize / 2; i--) {
-          padsRearrange.emplace_back(padsTensorValue[i - padsSize / 2]);
-          padsRearrange.emplace_back(padsTensorValue[i]);
+        for (int32_t i = begins.size() - 1; i >= 0; i--) {
+          padsRearrange.emplace_back(begins[i]);
+          padsRearrange.emplace_back(ends[i]);
         }
 
         Value padsSizeList =
diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir
index 43ced2e29..2e7b59088 100644
--- a/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir
+++ b/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir
@@ -1008,6 +1008,87 @@ func.func @test_pad_edge(%arg0: !torch.vtensor<[3,4],f32>, %arg1: !torch.vtensor
 
 // -----
 
+func.func @test_center_crop_pad_crop_axes_chw_expanded(%arg0: !torch.vtensor<[4,5],f32>, %arg1: !torch.vtensor<[4],si64>, %arg2: !torch.vtensor<[2],si64>) -> !torch.vtensor<[?,?],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64} {
+
+  // CHECK: %[[NONE:.+]] = torch.constant.none
+  // CHECK: %[[ZERO:.+]] = torch.constant.int 0
+
+  // CHECK: %[[IDX:.+]] = torch.constant.int 0
+  // CHECK: %[[SEL:.+]] = torch.aten.select.int %arg1, %[[ZERO]], %[[IDX]]
+  // CHECK: %[[PAD0:.+]] = torch.aten.item %[[SEL]]
+
+  // CHECK: %[[IDX:.+]] = torch.constant.int 1
+  // CHECK: %[[SEL:.+]] = torch.aten.select.int %arg1, %[[ZERO]], %[[IDX]]
+  // CHECK: %[[PAD1:.+]] = torch.aten.item %[[SEL]]
+
+  // CHECK: %[[IDX:.+]] = torch.constant.int 2
+  // CHECK: %[[SEL:.+]] = torch.aten.select.int %arg1, %[[ZERO]], %[[IDX]]
+  // CHECK: %[[PAD2:.+]] = torch.aten.item %[[SEL]]
+
+  // CHECK: %[[IDX:.+]] = torch.constant.int 3
+  // CHECK: %[[SEL:.+]] = torch.aten.select.int %arg1, %[[ZERO]], %[[IDX]]
+  // CHECK: %[[PAD3:.+]] = torch.aten.item %[[SEL]]
+
+  // CHECK: %[[ZERO:.+]] = torch.constant.int 0
+  // CHECK: %[[RANK:.+]] = torch.constant.int 2
+
+  // CHECK: %[[IDX:.+]] = torch.constant.int 0
+  // CHECK: %[[SEL:.+]] = torch.aten.select.int %arg2, %[[ZERO]], %[[IDX]]
+  // CHECK: %[[ITEM:.+]] = torch.aten.item %[[SEL]]
+  // CHECK: %[[LT:.+]] = torch.aten.lt.int %[[ITEM]], %[[ZERO]]
+  // CHECK: %[[INT:.+]] = torch.aten.Int.bool %[[LT]]
+  // CHECK: %[[MUL:.+]] = torch.aten.mul.int %[[INT]], %[[RANK]]
+  // CHECK: %[[AXIS0:.+]] = torch.aten.add.int %[[MUL]], %[[ITEM]]
+
+  // CHECK: %[[IDX:.+]] = torch.constant.int 1
+  // CHECK: %[[SEL:.+]] = torch.aten.select.int %arg2, %[[ZERO]], %[[IDX]]
+  // CHECK: %[[ITEM:.+]] = torch.aten.item %[[SEL]]
+  // CHECK: %[[LT:.+]] = torch.aten.lt.int %[[ITEM]], %[[ZERO]]
+  // CHECK: %[[INT:.+]] = torch.aten.Int.bool %[[LT]]
+  // CHECK: %[[MUL:.+]] = torch.aten.mul.int %[[INT]], %[[RANK]]
+  // CHECK: %[[AXIS1:.+]] = torch.aten.add.int %[[MUL]], %[[ITEM]]
+
+
+  // CHECK: %[[AX:.+]] = torch.constant.int 0
+  // CHECK: %[[EQ:.+]] = torch.aten.eq.int %[[AXIS0]], %[[AX]]
+  // CHECK: %[[INT:.+]] = torch.aten.Int.bool %[[EQ]]
+  // CHECK: %[[MUL0:.+]] = torch.aten.mul.int %[[INT]], %[[PAD0]]
+  // CHECK: %[[MUL1:.+]] = torch.aten.mul.int %[[INT]], %[[PAD2]]
+  // CHECK: %[[ADD0:.+]] = torch.aten.add.int %[[ZERO]], %[[MUL0]]
+  // CHECK: %[[ADD1:.+]] = torch.aten.add.int %[[ZERO]], %[[MUL1]]
+
+  // CHECK: %[[EQ:.+]] = torch.aten.eq.int %[[AXIS1]], %[[AX]]
+  // CHECK: %[[INT:.+]] = torch.aten.Int.bool %[[EQ]]
+  // CHECK: %[[MUL0:.+]] = torch.aten.mul.int %[[INT]], %[[PAD1]]
+  // CHECK: %[[MUL1:.+]] = torch.aten.mul.int %[[INT]], %[[PAD3]]
+  // CHECK: %[[BEGIN0:.+]] = torch.aten.add.int %[[ADD0]], %[[MUL0]]
+  // CHECK: %[[END0:.+]] = torch.aten.add.int %[[ADD1]], %[[MUL1]]
+
+  // CHECK: %[[AX:.+]] = torch.constant.int 1
+  // CHECK: %[[EQ:.+]] = torch.aten.eq.int %[[AXIS0]], %[[AX]]
+  // CHECK: %[[INT:.+]] = torch.aten.Int.bool %[[EQ]]
+  // CHECK: %[[MUL0:.+]] = torch.aten.mul.int %[[INT]], %[[PAD0]]
+  // CHECK: %[[MUL1:.+]] = torch.aten.mul.int %[[INT]], %[[PAD2]]
+  // CHECK: %[[ADD0:.+]] = torch.aten.add.int %[[ZERO]], %[[MUL0]]
+  // CHECK: %[[ADD1:.+]] = torch.aten.add.int %[[ZERO]], %[[MUL1]]
+
+  // CHECK: %[[EQ:.+]] = torch.aten.eq.int %[[AXIS1]], %[[AX]]
+  // CHECK: %[[INT:.+]] = torch.aten.Int.bool %[[EQ]]
+  // CHECK: %[[MUL0:.+]] = torch.aten.mul.int %[[INT]], %[[PAD1]]
+  // CHECK: %[[MUL1:.+]] = torch.aten.mul.int %[[INT]], %[[PAD3]]
+  // CHECK: %[[BEGIN1:.+]] = torch.aten.add.int %[[ADD0]], %[[MUL0]]
+  // CHECK: %[[END1:.+]] = torch.aten.add.int %[[ADD1]], %[[MUL1]]
+
+  // CHECK: %[[LIST:.+]] = torch.prim.ListConstruct %[[BEGIN1]], %[[END1]], %[[BEGIN0]], %[[END0]]
+  // CHECK: %[[MODE:.+]] = torch.constant.str "constant"
+  // CHECK: %[[PAD:.+]] = torch.aten.pad %arg0, %[[LIST]], %[[MODE]], %[[NONE]]
+  %none = torch.constant.none
+  %0 = torch.operator "onnx.Pad"(%arg0, %arg1, %none, %arg2) : (!torch.vtensor<[4,5],f32>, !torch.vtensor<[4],si64>, !torch.none, !torch.vtensor<[2],si64>) -> !torch.vtensor<[?,?],f32>
+  return %0 : !torch.vtensor<[?,?],f32>
+}
+
+// -----
+
 // CHECK-LABEL: func.func @test_pow
 func.func @test_pow(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 15 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
   // CHECK: torch.aten.pow.Tensor_Tensor %arg0, %arg1 : !torch.vtensor<[3,4,5],f32>, !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],f32>
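
Reviewer note: the `eq.int`/`Int.bool`/`mul.int`/`add.int` chain emitted above is a masked scatter, needed because both the axis indices and the pad amounts are SSA `!torch.int` values rather than compile-time constants. For reference, here is a minimal standalone sketch of the same arithmetic on plain integers; `rearrangePads` and its sample values are hypothetical illustrations, not part of the patch, and the sketch assumes unique `axes` entries (repeated axes would accumulate in the emitted IR):

```cpp
#include <cassert>
#include <cstdint>
#include <iostream>
#include <vector>

// Scatter per-axis ONNX pads [x1_begin, ..., xk_begin, x1_end, ..., xk_end]
// into full-rank begin/end lists, then reverse into torch.pad's expected
// order [xn_begin, xn_end, ..., x1_begin, x1_end].
std::vector<int64_t> rearrangePads(const std::vector<int64_t> &pads,
                                   const std::vector<int64_t> &axes,
                                   int64_t rank) {
  const size_t numAxes = axes.size();
  assert(pads.size() == 2 * numAxes && "one begin and one end per axis");
  std::vector<int64_t> begins(rank, 0), ends(rank, 0);
  for (size_t i = 0; i < numAxes; ++i) {
    // Normalize negative axes, mirroring the lt.int/Int.bool/mul.int/add.int
    // sequence the lowering builds per axis.
    const int64_t axis = axes[i] < 0 ? axes[i] + rank : axes[i];
    begins[axis] = pads[i];
    ends[axis] = pads[numAxes + i];
  }
  std::vector<int64_t> rearranged;
  for (int64_t j = rank - 1; j >= 0; --j) {
    rearranged.push_back(begins[j]);
    rearranged.push_back(ends[j]);
  }
  return rearranged;
}

int main() {
  // Pads {1, 2, 3, 4} over axes {0, 1} of a rank-2 input: axis 0 gets
  // begin=1/end=3 and axis 1 gets begin=2/end=4, so torch.pad receives
  // [2, 4, 1, 3].
  for (int64_t v : rearrangePads({1, 2, 3, 4}, {0, 1}, 2))
    std::cout << v << ' ';
  std::cout << '\n';
}
```

When `axes` is omitted the scatter is skipped entirely: `begins`/`ends` already hold one pair per dimension, and only the final reversal into torch.pad order applies.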