[onnx] Fix constant pad for dynamic shape (#2989)

The existing onnx.Pad lowering was not functional for dynamic shapes.
Updated the lowering and enabled the previously xfailed tests so that the onnx.Pad tests pass.

Work TBD for reflection padding.
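For context (not part of the diff), a rough NumPy/PyTorch sketch of the two padding layouts the lowering has to translate between; the shapes and pad amounts below are made up for illustration:

import numpy as np
import torch
import torch.nn.functional as F

x = np.arange(12, dtype=np.float32).reshape(3, 4)

# onnx.Pad groups all "begin" amounts first, then all "end" amounts:
#   pads = [x1_begin, x2_begin, ..., x1_end, x2_end, ...]
onnx_pads = [1, 0, 1, 0]  # pad dim 0 by one row on each side, leave dim 1 alone
pad_width = list(zip(onnx_pads[:x.ndim], onnx_pads[x.ndim:]))  # per-dim (before, after)
np_padded = np.pad(x, pad_width, mode="constant", constant_values=0.0)

# torch's pad list holds (begin, end) pairs starting from the last dimension:
torch_padded = F.pad(torch.from_numpy(x), [0, 0, 1, 1], mode="constant", value=0.0)

assert np_padded.shape == tuple(torch_padded.shape) == (5, 4)

In the dynamic case this patch targets, the pads arrive as a runtime tensor rather than as compile-time constants, so the ONNX lowering below now extracts each element with aten.select.int + aten.item and builds a prim.ListConstruct instead of requiring a constant int list.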
Rob Suderman 2024-03-07 13:29:50 -08:00 committed by GitHub
parent 7b18646def
commit 1964208d19
5 changed files with 105 additions and 98 deletions


@@ -908,7 +908,7 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
return success();
});
patterns.onOp(
"Pad", 19, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
"Pad", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
Torch::ValueTensorType resultType;
Value data, pads, axes;
std::string mode;
@@ -925,36 +925,6 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
return failure();
Location loc = binder.getLoc();
Value constantValue;
if (binder.getNumOperands() >= 3) {
if (binder.tensorOperandAtIndex(constantValue, 2)) {
llvm::errs() << "failed to bind to index 2\n";
return failure();
}
} else {
auto dataTensorType = data.getType().cast<Torch::ValueTensorType>();
auto maybeZeroAttr = [&]() -> std::optional<Attribute> {
if (dataTensorType.getDtype().isa<IntegerType>()) {
return rewriter.getI64IntegerAttr(0);
}
if (dataTensorType.getDtype().isa<FloatType>()) {
return rewriter.getFloatAttr(dataTensorType.getDtype(), 0.0f);
}
return std::nullopt;
}();
if (!maybeZeroAttr) {
return rewriter.notifyMatchFailure(
binder.op, "expected integer or float data tensor");
}
auto shapedType = dataTensorType.toBuiltinTensor();
auto splat = SplatElementsAttr::get(shapedType, *maybeZeroAttr);
constantValue = rewriter.create<Torch::ValueTensorLiteralOp>(
loc, dataTensorType, splat);
}
// Get pads shape and rank. The pads tensor is expected to be 1-D
// tensor.
auto padsTensorType = pads.getType().cast<Torch::ValueTensorType>();
@@ -964,14 +934,48 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
}
ArrayRef<int64_t> padsShape = padsTensorType.getSizes();
int64_t padsRank = padsShape.size();
if (padsRank != 1) {
if (padsRank != 1)
return rewriter.notifyMatchFailure(binder.op,
"Expect 1-D pad tensor");
"expect 1-d pad tensor");
int64_t padsSize = padsShape[0];
if (padsSize == Torch::kUnknownSize)
return rewriter.notifyMatchFailure(binder.op,
"pad length is unknown");
Value constantValue;
if (binder.getNumOperands() >= 3) {
if (!binder.tensorOperandAtIndex(constantValue, 2)) {
auto constTy =
dyn_cast<Torch::BaseTensorType>(constantValue.getType());
if (!constTy || !constTy.hasDtype())
return rewriter.notifyMatchFailure(
binder.op, "constant ty is unsupport type");
Type scalarTy = rewriter.getType<Torch::IntType>();
if (isa<FloatType>(constTy.getDtype()))
scalarTy = rewriter.getType<Torch::FloatType>();
constantValue = rewriter.create<Torch::AtenItemOp>(loc, scalarTy,
constantValue);
}
}
if (!constantValue) {
auto dataTensorType = data.getType().cast<Torch::ValueTensorType>();
if (dataTensorType.getDtype().isa<IntegerType>())
constantValue = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(0));
if (dataTensorType.getDtype().isa<FloatType>())
constantValue = rewriter.create<Torch::ConstantFloatOp>(
loc, rewriter.getF64FloatAttr(0.0f));
if (!constantValue)
return rewriter.notifyMatchFailure(
binder.op, "expected integer or float data tensor");
}
// Extract all the values of 1-D pad tensor and create a list of all
// these values as torch.pad op expects pad list.
int64_t padsSize = padsShape[0];
Value constZero = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(0));
SmallVector<Value> padsTensorValue;
@@ -982,8 +986,11 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
for (uint32_t i = 0; i < padsSize; ++i) {
Value index = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(i));
padsTensorValue.emplace_back(rewriter.create<Torch::AtenSelectIntOp>(
loc, padsElemType, pads, constZero, index));
auto select = rewriter.create<Torch::AtenSelectIntOp>(
loc, padsElemType, pads, constZero, index);
Value selectInt = rewriter.create<Torch::AtenItemOp>(
loc, rewriter.getType<Torch::IntType>(), select);
padsTensorValue.push_back(selectInt);
}
// The torch.pad op expects a different arrangement of padding pairs for
@@ -991,43 +998,22 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
// tensor to satisfy torch.pad op semantics.
SmallVector<Value> padsRearrange;
for (uint32_t i = 0; i < padsSize / 2; i++) {
padsRearrange.emplace_back(padsTensorValue[(padsSize / 2) - 1 - i]);
padsRearrange.emplace_back(padsTensorValue[padsSize - 1 - i]);
padsRearrange.emplace_back(padsTensorValue[i]);
padsRearrange.emplace_back(padsTensorValue[(padsSize / 2) + i]);
}
Value padsSizeList =
rewriter
.create<Torch::PrimTolistOp>(
.create<Torch::PrimListConstructOp>(
loc,
Torch::ListType::get(rewriter.getType<Torch::IntType>()),
padsRearrange)
.getResult(0);
.getResult();
Value modeVal = rewriter.create<Torch::ConstantStrOp>(
loc, rewriter.getStringAttr(mode));
// The constant value is a 0-d tensor, which needs to be converted to a
// float scalar as torch.pad op expects a float scalar.
auto constValueType =
constantValue.getType().cast<Torch::ValueTensorType>();
if (!constValueType) {
return rewriter.notifyMatchFailure(binder.op,
"Expect non-none constant value");
}
auto resultTensorType = Torch::ValueTensorType::get(
constValueType.getContext(), emptyShape, rewriter.getF64Type());
Value none = rewriter.create<Torch::ConstantNoneOp>(loc);
Value cstFalse = rewriter.create<Torch::ConstantBoolOp>(loc, false);
Value constFloatValue = rewriter.create<Torch::AtenToDtypeOp>(
loc, resultTensorType, constantValue,
Torch::getDtypeIntValueForType(rewriter, loc,
resultTensorType.getOptionalDtype()),
/*non_blocking=*/cstFalse, /*copy=*/cstFalse,
/*memory_format=*/none);
Value constScalar = rewriter.create<Torch::AtenItemOp>(
loc, rewriter.getType<Torch::FloatType>(), constFloatValue);
rewriter.replaceOpWithNewOp<Torch::AtenPadOp>(
binder.op, resultType, data, padsSizeList, modeVal, constScalar);
binder.op, resultType, data, padsSizeList, modeVal, constantValue);
return success();
});
patterns.onOp("Pow", 1,


@@ -45,36 +45,69 @@ public:
auto type = self.getType().cast<RankedTensorType>();
int64_t rank = type.getRank();
// Pattern match against the op's original operands, because otherwise we
// will get the lowered version of the operands which is harder to pattern
// match.
SmallVector<int64_t> padInts;
if (!matchPattern(op.getPad(), m_TorchListOfConstantInts(padInts)))
return rewriter.notifyMatchFailure(
op, "only support constant int pad ranges");
uint64_t padRank = padInts.size() / 2;
if (padRank * 2 != padInts.size())
auto primList = op.getPad().getDefiningOp<Torch::PrimListConstructOp>();
if (!primList) {
return rewriter.notifyMatchFailure(op, "unable to get pad values");
}
SmallVector<Value> padVals(primList.getOperands());
uint64_t padRank = padVals.size() / 2;
if (padRank * 2 != padVals.size())
return rewriter.notifyMatchFailure(op, "pad range size is not even");
if (rank < 0 || padRank > (uint64_t)rank)
return rewriter.notifyMatchFailure(op, "padding exceeds tensor rank");
// Initialize low/high paddings with the dims that should not be padded.
SmallVector<int64_t, 4> lowPadding(/*Size=*/rank - padRank, /*Value=*/0);
SmallVector<int64_t, 4> highPadding(/*Size=*/rank - padRank, /*Value=*/0);
int64_t noPad = rank - padRank;
Attribute zero = rewriter.getIndexAttr(0);
SmallVector<int64_t> staticLow(noPad, 0);
SmallVector<int64_t> staticHigh(noPad, 0);
SmallVector<OpFoldResult> lowPad(noPad, zero);
SmallVector<OpFoldResult> highPad(noPad, zero);
auto tc = getTypeConverter();
// Add the requested padding - note op.pad() is highest dim first ordered
// pairs of low,high.
for (uint64_t i = padRank; i > 0; --i) {
lowPadding.push_back(padInts[i * 2 - 2]);
highPadding.push_back(padInts[i * 2 - 1]);
int64_t lowi, highi;
Value lowv = padVals[i * 2 - 2];
Value highv = padVals[i * 2 - 1];
if (!matchPattern(lowv, m_TorchConstantInt(&lowi))) {
Type cty = tc->convertType(lowv.getType());
lowv = tc->materializeTargetConversion(rewriter, loc, cty, lowv);
lowv = rewriter.create<arith::IndexCastOp>(loc, rewriter.getIndexType(),
lowv);
lowPad.push_back(lowv);
staticLow.push_back(ShapedType::kDynamic);
} else {
lowPad.push_back(rewriter.getIndexAttr(lowi));
staticLow.push_back(lowi);
}
if (!matchPattern(highv, m_TorchConstantInt(&highi))) {
Type cty = tc->convertType(highv.getType());
highv = tc->materializeTargetConversion(rewriter, loc, cty, highv);
highv = rewriter.create<arith::IndexCastOp>(
loc, rewriter.getIndexType(), highv);
highPad.push_back(highv);
staticHigh.push_back(ShapedType::kDynamic);
} else {
highPad.push_back(rewriter.getIndexAttr(highi));
staticHigh.push_back(highi);
}
}
Type newResultType = getTypeConverter()->convertType(op.getType());
Type elementType = newResultType.cast<RankedTensorType>().getElementType();
Value castedValue =
convertScalarToDtype(rewriter, loc, adaptor.getValue(), elementType);
Value paddedInput = torch_to_linalg::getPaddedTensor(
op, rewriter, self, lowPadding, highPadding, castedValue);
Type padType = tensor::PadOp::inferResultType(
self.getType().cast<RankedTensorType>(), staticLow, staticHigh);
Value paddedInput = rewriter.create<tensor::PadOp>(
loc, padType, self, lowPad, highPad, castedValue);
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, paddedInput);
return success();
}
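Not part of the diff, but roughly the shape of IR the updated TorchToLinalg lowering emits when one pad amount is not a compile-time constant (value names here are invented): the non-constant amount is index_cast to an index SSA value and passed to tensor.pad, and the padded dimension becomes dynamic in the inferred result type.

%zero = arith.constant 0.0 : f32
%p = arith.index_cast %pad_i64 : i64 to index
%padded = tensor.pad %input low[%p, 0] high[%p, 0] {
^bb0(%i: index, %j: index):
  tensor.yield %zero : f32
} : tensor<?x4xf32> to tensor<?x4xf32>

The final tensor.cast then reconciles the inferred pad type with the converted result type of the original op.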


@@ -1955,12 +1955,6 @@ ONNX_XFAIL_SET = {
"OneHotModule_basic",
# Failure - onnx_lowering: onnx.Pad
"ConstantPad2dStaticModule_basic",
"ConstantPadNdModule_basic",
"ConstantPadNdPartialStaticModule_basic",
"ConstantPadNdStaticModule_basic",
"PadModule_basic",
"PadWithNoneValModule_basic",
"ReflectionPad1dModule2dInput_Right",
"ReflectionPad1dModule2dInput_basic",
"ReflectionPad1dModule3dInput_Left",


@@ -109,5 +109,3 @@ class ReflectionPad2dModuleRight(torch.nn.Module):
@register_test_case(module_factory=lambda: ReflectionPad2dModuleRight())
def ReflectionPad2dModule_Right(module, tu: TestUtils):
module.forward(tu.rand(2, 3, 20, 20))
# ==============================================================================


@@ -447,23 +447,23 @@ func.func @test_less_or_equal(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.
// CHECK-LABEL: func.func @test_pad
func.func @test_pad(%arg0: !torch.vtensor<[3,4],f32>, %arg1: !torch.vtensor<[4], si64>, %arg2: !torch.vtensor<[], f32>) -> !torch.vtensor<[5,4],f32> attributes {torch.onnx_meta.opset_version = 19 : si64} {
// CHECK: %[[VAL:.+]] = torch.aten.item %arg2 : !torch.vtensor<[],f32> -> !torch.float
// CHECK: %[[INT0:.+]] = torch.constant.int 0
// CHECK: %[[INT0_0:.+]] = torch.constant.int 0
// CHECK: %[[SELECT_0:.+]] = torch.aten.select.int %arg1, %[[INT0]], %[[INT0_0]] : !torch.vtensor<[4],si64>, !torch.int, !torch.int -> !torch.vtensor<[],si64>
// CHECK: %[[ITEM_0:.+]] = torch.aten.item %[[SELECT_0]] : !torch.vtensor<[],si64> -> !torch.int
// CHECK: %[[INT1:.+]] = torch.constant.int 1
// CHECK: %[[SELECT_1:.+]] = torch.aten.select.int %arg1, %[[INT0]], %[[INT1]] : !torch.vtensor<[4],si64>, !torch.int, !torch.int -> !torch.vtensor<[],si64>
// CHECK: %[[ITEM_1:.+]] = torch.aten.item %[[SELECT_1]] : !torch.vtensor<[],si64> -> !torch.int
// CHECK: %[[INT2:.+]] = torch.constant.int 2
// CHECK: %[[SELECT_2:.+]] = torch.aten.select.int %arg1, %[[INT0]], %[[INT2]] : !torch.vtensor<[4],si64>, !torch.int, !torch.int -> !torch.vtensor<[],si64>
// CHECK: %[[ITEM_2:.+]] = torch.aten.item %[[SELECT_2]] : !torch.vtensor<[],si64> -> !torch.int
// CHECK: %[[INT3:.+]] = torch.constant.int 3
// CHECK: %[[SELECT_3:.+]] = torch.aten.select.int %arg1, %[[INT0]], %[[INT3]] : !torch.vtensor<[4],si64>, !torch.int, !torch.int -> !torch.vtensor<[],si64>
// CHECK: %[[LIST:.+]] = torch.prim.tolist(%[[SELECT_1]], %[[SELECT_3]], %[[SELECT_0]], %[[SELECT_2]]) : !torch.vtensor<[],si64>, !torch.vtensor<[],si64>, !torch.vtensor<[],si64>, !torch.vtensor<[],si64> -> !torch.list<int>
// CHECK: %[[ITEM_3:.+]] = torch.aten.item %[[SELECT_3]] : !torch.vtensor<[],si64> -> !torch.int
// CHECK: %[[LIST:.+]] = torch.prim.ListConstruct %[[ITEM_0]], %[[ITEM_2]], %[[ITEM_1]], %[[ITEM_3]] : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[STR:.+]] = torch.constant.str "constant"
// CHECK: %[[NONE:.+]] = torch.constant.none
// CHECK: %[[FALSE:.+]] = torch.constant.bool false
// CHECK: %[[INT7:.+]] = torch.constant.int 7
// CHECK: %[[CONVERT:.+]] = torch.aten.to.dtype %arg2, %[[INT7]], %[[FALSE]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[],f64>
// CHECK: %[[ITEM:.+]] = torch.aten.item %[[CONVERT]] : !torch.vtensor<[],f64> -> !torch.float
// CHECK: %[[PAD:.+]] = torch.aten.pad %arg0, %[[LIST]], %[[STR]], %[[ITEM]] : !torch.vtensor<[3,4],f32>, !torch.list<int>, !torch.str, !torch.float -> !torch.vtensor<[5,4],f32>
// CHECK: %[[PAD:.+]] = torch.aten.pad %arg0, %[[LIST]], %[[STR]], %[[VAL]] : !torch.vtensor<[3,4],f32>, !torch.list<int>, !torch.str, !torch.float -> !torch.vtensor<[5,4],f32>
// CHECK: return %[[PAD]] : !torch.vtensor<[5,4],f32>
%0 = torch.operator "onnx.Pad"(%arg0, %arg1, %arg2) {torch.onnx.mode = "constant"} : (!torch.vtensor<[3,4],f32>, !torch.vtensor<[4], si64>, !torch.vtensor<[], f32>) -> !torch.vtensor<[5,4],f32>
return %0 : !torch.vtensor<[5,4],f32>
@@ -474,13 +474,9 @@ func.func @test_pad(%arg0: !torch.vtensor<[3,4],f32>, %arg1: !torch.vtensor<[4],
// CHECK-LABEL: @test_pad_optional_constant
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[3,4],f32>
// CHECK-SAME: %[[ARG1:.*]]: !torch.vtensor<[4],si64>
// CHECK: %[[VAL:.+]] = torch.constant.float 0
// CHECK: %[[CONST_STR:.*]] = torch.constant.str "constant"
// CHECK: %[[NONE:.*]] = torch.constant.none
// CHECK: %[[FALSE:.*]] = torch.constant.bool false
// CHECK: %[[SEVEN:.*]] = torch.constant.int 7
// CHECK: %[[DTYPE:.*]] = torch.aten.to.dtype %0, %[[SEVEN]], %[[FALSE]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[3,4],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[],f64>
// CHECK: %[[ITEM:.*]] = torch.aten.item %[[DTYPE]] : !torch.vtensor<[],f64> -> !torch.float
// CHECK: torch.aten.pad %[[ARG0]], %{{.*}}, %[[CONST_STR]], %[[ITEM]] : !torch.vtensor<[3,4],f32>, !torch.list<int>, !torch.str, !torch.float -> !torch.vtensor<[5,4],f32>
// CHECK: torch.aten.pad %[[ARG0]], %{{.*}}, %[[CONST_STR]], %[[VAL]] : !torch.vtensor<[3,4],f32>, !torch.list<int>, !torch.str, !torch.float -> !torch.vtensor<[5,4],f32>
func.func @test_pad_optional_constant(%arg0: !torch.vtensor<[3,4],f32>, %arg1: !torch.vtensor<[4], si64>) -> !torch.vtensor<[5,4],f32> attributes {torch.onnx_meta.opset_version = 19 : si64} {
%0 = torch.operator "onnx.Pad"(%arg0, %arg1) {torch.onnx.mode = "constant"} : (!torch.vtensor<[3,4],f32>, !torch.vtensor<[4], si64>) -> !torch.vtensor<[5,4],f32>