mirror of https://github.com/llvm/torch-mlir
[onnx] Support for optional `axis` attribute for `onnx.Pad` (#3635)
The `axis` attribute is optionally available. Added support by computing the pad based on the axis values. --------- Signed-off-by: Rob Suderman <rob.suderman@gmail.com>pull/3655/merge
parent
b3b8e2e96a
commit
6cf139687d
|
@ -2700,15 +2700,13 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
|
|||
Value data, pads, axes;
|
||||
std::string mode;
|
||||
|
||||
// TODO: The `axes` parameter is not supported yet.
|
||||
if (!binder.tensorOperandAtIndex(axes, 3)) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
binder.op, "The axes parameter is not supported yet");
|
||||
}
|
||||
if (binder.tensorOperandAtIndex(data, 0) ||
|
||||
binder.tensorResultType(resultType) ||
|
||||
binder.customOpNameStringAttr(mode, "mode", "constant"))
|
||||
return failure();
|
||||
|
||||
(void)binder.tensorOperandAtIndex(axes, 3);
|
||||
|
||||
bool cstMode = (mode == "constant");
|
||||
|
||||
// get input rank
|
||||
|
@ -2822,6 +2820,90 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
|
|||
if (!cstMode)
|
||||
constantValue = rewriter.create<Torch::ConstantNoneOp>(loc);
|
||||
|
||||
llvm::SmallVector<Value> begins;
|
||||
llvm::SmallVector<Value> ends;
|
||||
for (uint32_t i = 0; i < padsSize / 2; ++i)
|
||||
begins.push_back(padsTensorValue[i]);
|
||||
for (uint32_t i = padsSize / 2; i < padsSize; ++i)
|
||||
ends.push_back(padsTensorValue[i]);
|
||||
|
||||
// If we have the axes we need to compute the appropriate pads:
|
||||
if (axes) {
|
||||
auto axesTy = cast<Torch::ValueTensorType>(axes.getType());
|
||||
assert(axesTy.getSizes().size() == 1);
|
||||
assert(axesTy.getSizes()[0] != Torch::kUnknownSize);
|
||||
|
||||
auto dataTensorType = cast<Torch::ValueTensorType>(data.getType());
|
||||
int64_t rank = dataTensorType.getSizes().size();
|
||||
auto boolTy = rewriter.getType<Torch::BoolType>();
|
||||
auto intTy = rewriter.getType<Torch::IntType>();
|
||||
Value constZero = rewriter.create<Torch::ConstantIntOp>(
|
||||
loc, rewriter.getI64IntegerAttr(0));
|
||||
|
||||
// Extract the values:
|
||||
int64_t numAxes = axesTy.getSizes()[0];
|
||||
Type axesElemType = Torch::ValueTensorType::get(
|
||||
axesTy.getContext(), ArrayRef<int64_t>{},
|
||||
axesTy.getOptionalDtype());
|
||||
llvm::SmallVector<Value> axesExtracted;
|
||||
Value rankV = rewriter.create<Torch::ConstantIntOp>(
|
||||
loc, rewriter.getI64IntegerAttr(rank));
|
||||
for (uint32_t i = 0; i < numAxes; ++i) {
|
||||
Value index = rewriter.create<Torch::ConstantIntOp>(
|
||||
loc, rewriter.getI64IntegerAttr(i));
|
||||
auto select = rewriter.create<Torch::AtenSelectIntOp>(
|
||||
loc, axesElemType, axes, constZero, index);
|
||||
Value selectInt = rewriter.create<Torch::AtenItemOp>(
|
||||
loc, rewriter.getType<Torch::IntType>(), select);
|
||||
|
||||
Value negAxis = rewriter.create<Torch::AtenLtIntOp>(
|
||||
loc, boolTy, selectInt, constZero);
|
||||
negAxis =
|
||||
rewriter.create<Torch::AtenIntBoolOp>(loc, intTy, negAxis);
|
||||
Value axis = rewriter.create<Torch::AtenMulIntOp>(loc, intTy,
|
||||
negAxis, rankV);
|
||||
axis = rewriter.create<Torch::AtenAddIntOp>(loc, intTy, axis,
|
||||
selectInt);
|
||||
axesExtracted.push_back(axis);
|
||||
}
|
||||
|
||||
llvm::SmallVector<Value> newBegins;
|
||||
llvm::SmallVector<Value> newEnds;
|
||||
|
||||
for (int j = 0; j < rank; ++j) {
|
||||
Value newBegin = constZero;
|
||||
Value newEnd = constZero;
|
||||
Value iv = rewriter.create<Torch::ConstantIntOp>(
|
||||
loc, rewriter.getI64IntegerAttr(j));
|
||||
|
||||
for (size_t i = 0; i < axesExtracted.size(); ++i) {
|
||||
Value begin = begins[i];
|
||||
Value end = ends[i];
|
||||
|
||||
Value sameAxis = rewriter.create<Torch::AtenEqIntOp>(
|
||||
loc, boolTy, axesExtracted[i], iv);
|
||||
sameAxis =
|
||||
rewriter.create<Torch::AtenIntBoolOp>(loc, intTy, sameAxis);
|
||||
|
||||
begin = rewriter.create<Torch::AtenMulIntOp>(loc, intTy, sameAxis,
|
||||
begin);
|
||||
end = rewriter.create<Torch::AtenMulIntOp>(loc, intTy, sameAxis,
|
||||
end);
|
||||
|
||||
newBegin = rewriter.create<Torch::AtenAddIntOp>(loc, intTy,
|
||||
newBegin, begin);
|
||||
newEnd =
|
||||
rewriter.create<Torch::AtenAddIntOp>(loc, intTy, newEnd, end);
|
||||
}
|
||||
|
||||
newBegins.push_back(newBegin);
|
||||
newEnds.push_back(newEnd);
|
||||
}
|
||||
|
||||
begins = std::move(newBegins);
|
||||
ends = std::move(newEnds);
|
||||
}
|
||||
|
||||
// The torch.pad op expects a different arrangement of padding pairs for
|
||||
// each dimension as compared to the onnx.pad op. Rearrange the pad
|
||||
// tensor as shown below:
|
||||
|
@ -2829,9 +2911,9 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
|
|||
// [x1_begin, x2_begin, ..., x1_end, x2_end,...] ->
|
||||
// [xn_begin, xn_end, ...., x2_begin, x2_end, x1_begin, x1_end]
|
||||
SmallVector<Value> padsRearrange;
|
||||
for (uint32_t i = padsSize - 1; i >= padsSize / 2; i--) {
|
||||
padsRearrange.emplace_back(padsTensorValue[i - padsSize / 2]);
|
||||
padsRearrange.emplace_back(padsTensorValue[i]);
|
||||
for (int32_t i = begins.size() - 1; i >= 0; i--) {
|
||||
padsRearrange.emplace_back(begins[i]);
|
||||
padsRearrange.emplace_back(ends[i]);
|
||||
}
|
||||
|
||||
Value padsSizeList =
|
||||
|
|
|
@ -1008,6 +1008,87 @@ func.func @test_pad_edge(%arg0: !torch.vtensor<[3,4],f32>, %arg1: !torch.vtensor
|
|||
|
||||
// -----
|
||||
|
||||
func.func @test_center_crop_pad_crop_axes_chw_expanded(%arg0: !torch.vtensor<[4,5],f32>, %arg1: !torch.vtensor<[4],si64>, %arg2: !torch.vtensor<[2],si64>) -> !torch.vtensor<[?,?],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64} {
|
||||
|
||||
// CHECK: %[[NONE:.+]] = torch.constant.none
|
||||
// CHECK: %[[ZERO:.+]] = torch.constant.int 0
|
||||
|
||||
// CHECK: %[[IDX:.+]] = torch.constant.int 0
|
||||
// CHECK: %[[SEL:.+]] = torch.aten.select.int %arg1, %[[ZERO]], %[[IDX]]
|
||||
// CHECK: %[[PAD0:.+]] = torch.aten.item %[[SEL]]
|
||||
|
||||
// CHECK: %[[IDX:.+]] = torch.constant.int 1
|
||||
// CHECK: %[[SEL:.+]] = torch.aten.select.int %arg1, %[[ZERO]], %[[IDX]]
|
||||
// CHECK: %[[PAD1:.+]] = torch.aten.item %[[SEL]]
|
||||
|
||||
// CHECK: %[[IDX:.+]] = torch.constant.int 2
|
||||
// CHECK: %[[SEL:.+]] = torch.aten.select.int %arg1, %[[ZERO]], %[[IDX]]
|
||||
// CHECK: %[[PAD2:.+]] = torch.aten.item %[[SEL]]
|
||||
|
||||
// CHECK: %[[IDX:.+]] = torch.constant.int 3
|
||||
// CHECK: %[[SEL:.+]] = torch.aten.select.int %arg1, %[[ZERO]], %[[IDX]]
|
||||
// CHECK: %[[PAD3:.+]] = torch.aten.item %[[SEL]]
|
||||
|
||||
// CHECK: %[[ZERO:.+]] = torch.constant.int 0
|
||||
// CHECK: %[[RANK:.+]] = torch.constant.int 2
|
||||
|
||||
// CHECK: %[[IDX:.+]] = torch.constant.int 0
|
||||
// CHECK: %[[SEL:.+]] = torch.aten.select.int %arg2, %[[ZERO]], %[[IDX]]
|
||||
// CHECK: %[[ITEM:.+]] = torch.aten.item %[[SEL]]
|
||||
// CHECK: %[[LT:.+]] = torch.aten.lt.int %[[ITEM]], %[[ZERO]]
|
||||
// CHECK: %[[INT:.+]] = torch.aten.Int.bool %[[LT]]
|
||||
// CHECK: %[[MUL:.+]] = torch.aten.mul.int %[[INT]], %[[RANK]]
|
||||
// CHECK: %[[AXIS0:.+]] = torch.aten.add.int %[[MUL]], %[[ITEM]]
|
||||
|
||||
// CHECK: %[[IDX:.+]] = torch.constant.int 1
|
||||
// CHECK: %[[SEL:.+]] = torch.aten.select.int %arg2, %[[ZERO]], %[[IDX]]
|
||||
// CHECK: %[[ITEM:.+]] = torch.aten.item %[[SEL]]
|
||||
// CHECK: %[[LT:.+]] = torch.aten.lt.int %[[ITEM]], %[[ZERO]]
|
||||
// CHECK: %[[INT:.+]] = torch.aten.Int.bool %[[LT]]
|
||||
// CHECK: %[[MUL:.+]] = torch.aten.mul.int %[[INT]], %[[RANK]]
|
||||
// CHECK: %[[AXIS1:.+]] = torch.aten.add.int %[[MUL]], %[[ITEM]]
|
||||
|
||||
|
||||
// CHECK: %[[AX:.+]] = torch.constant.int 0
|
||||
// CHECK: %[[EQ:.+]] = torch.aten.eq.int %[[AXIS0]], %[[AX]]
|
||||
// CHECK: %[[INT:.+]] = torch.aten.Int.bool %[[EQ]]
|
||||
// CHECK: %[[MUL0:.+]] = torch.aten.mul.int %[[INT]], %[[PAD0]]
|
||||
// CHECK: %[[MUL1:.+]] = torch.aten.mul.int %[[INT]], %[[PAD2]]
|
||||
// CHECK: %[[ADD0:.+]] = torch.aten.add.int %[[ZERO]], %[[MUL0]]
|
||||
// CHECK: %[[ADD1:.+]] = torch.aten.add.int %[[ZERO]], %[[MUL1]]
|
||||
|
||||
// CHECK: %[[EQ:.+]] = torch.aten.eq.int %[[AXIS1]], %[[AX]]
|
||||
// CHECK: %[[INT:.+]] = torch.aten.Int.bool %[[EQ]]
|
||||
// CHECK: %[[MUL0:.+]] = torch.aten.mul.int %[[INT]], %[[PAD1]]
|
||||
// CHECK: %[[MUL1:.+]] = torch.aten.mul.int %[[INT]], %[[PAD3]]
|
||||
// CHECK: %[[BEGIN0:.+]] = torch.aten.add.int %[[ADD0]], %[[MUL0]]
|
||||
// CHECK: %[[END0:.+]] = torch.aten.add.int %[[ADD1]], %[[MUL1]]
|
||||
|
||||
// CHECK: %[[AX:.+]] = torch.constant.int 1
|
||||
// CHECK: %[[EQ:.+]] = torch.aten.eq.int %[[AXIS0]], %[[AX]]
|
||||
// CHECK: %[[INT:.+]] = torch.aten.Int.bool %[[EQ]]
|
||||
// CHECK: %[[MUL0:.+]] = torch.aten.mul.int %[[INT]], %[[PAD0]]
|
||||
// CHECK: %[[MUL1:.+]] = torch.aten.mul.int %[[INT]], %[[PAD2]]
|
||||
// CHECK: %[[ADD0:.+]] = torch.aten.add.int %[[ZERO]], %[[MUL0]]
|
||||
// CHECK: %[[ADD1:.+]] = torch.aten.add.int %[[ZERO]], %[[MUL1]]
|
||||
|
||||
// CHECK: %[[EQ:.+]] = torch.aten.eq.int %[[AXIS1]], %[[AX]]
|
||||
// CHECK: %[[INT:.+]] = torch.aten.Int.bool %[[EQ]]
|
||||
// CHECK: %[[MUL0:.+]] = torch.aten.mul.int %[[INT]], %[[PAD1]]
|
||||
// CHECK: %[[MUL1:.+]] = torch.aten.mul.int %[[INT]], %[[PAD3]]
|
||||
// CHECK: %[[BEGIN1:.+]] = torch.aten.add.int %[[ADD0]], %[[MUL0]]
|
||||
// CHECK: %[[END1:.+]] = torch.aten.add.int %[[ADD1]], %[[MUL1]]
|
||||
|
||||
// CHECK: %[[LIST:.+]] = torch.prim.ListConstruct %[[BEGIN1]], %[[END1]], %[[BEGIN0]], %[[END0]]
|
||||
// CHECK: %[[MODE:.+]] = torch.constant.str "constant"
|
||||
// CHECK: %[[PAD:.+]] = torch.aten.pad %arg0, %[[LIST]], %[[MODE]], %[[NONE]]
|
||||
%none = torch.constant.none
|
||||
%0 = torch.operator "onnx.Pad"(%arg0, %arg1, %none, %arg2) : (!torch.vtensor<[4,5],f32>, !torch.vtensor<[4],si64>, !torch.none, !torch.vtensor<[2],si64>) -> !torch.vtensor<[?,?],f32>
|
||||
return %0 : !torch.vtensor<[?,?],f32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @test_pow
|
||||
func.func @test_pow(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 15 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||
// CHECK: torch.aten.pow.Tensor_Tensor %arg0, %arg1 : !torch.vtensor<[3,4,5],f32>, !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],f32>
|
||||
|
|
Loading…
Reference in New Issue