mirror of https://github.com/llvm/torch-mlir
[onnx] Fix constant pad for dynamic shape (#2989)
The current padding operation was not functional for dynamic shapes. Updated and enabled tests so that onnx.pad tests pass. Work TBD for reflection padding.
parent 7b18646def
commit 1964208d19
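For context, the case this change targets is a constant-mode onnx.Pad whose pad amounts come from a runtime tensor, so the padded shape cannot be computed at import time. Roughly (the dynamic result shape here is illustrative; the real test cases appear in the diff below):

  %padded = torch.operator "onnx.Pad"(%data, %pads, %value) {torch.onnx.mode = "constant"}
      : (!torch.vtensor<[3,4],f32>, !torch.vtensor<[4],si64>, !torch.vtensor<[],f32>) -> !torch.vtensor<[?,?],f32>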
@@ -908,7 +908,7 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
         return success();
       });
   patterns.onOp(
-      "Pad", 19, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
+      "Pad", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
         Torch::ValueTensorType resultType;
         Value data, pads, axes;
         std::string mode;
@@ -925,36 +925,6 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
           return failure();
         Location loc = binder.getLoc();

-        Value constantValue;
-        if (binder.getNumOperands() >= 3) {
-          if (binder.tensorOperandAtIndex(constantValue, 2)) {
-            llvm::errs() << "failed to bind to index 2\n";
-            return failure();
-          }
-        } else {
-          auto dataTensorType = data.getType().cast<Torch::ValueTensorType>();
-
-          auto maybeZeroAttr = [&]() -> std::optional<Attribute> {
-            if (dataTensorType.getDtype().isa<IntegerType>()) {
-              return rewriter.getI64IntegerAttr(0);
-            }
-            if (dataTensorType.getDtype().isa<FloatType>()) {
-              return rewriter.getFloatAttr(dataTensorType.getDtype(), 0.0f);
-            }
-            return std::nullopt;
-          }();
-
-          if (!maybeZeroAttr) {
-            return rewriter.notifyMatchFailure(
-                binder.op, "expected integer or float data tensor");
-          }
-
-          auto shapedType = dataTensorType.toBuiltinTensor();
-          auto splat = SplatElementsAttr::get(shapedType, *maybeZeroAttr);
-          constantValue = rewriter.create<Torch::ValueTensorLiteralOp>(
-              loc, dataTensorType, splat);
-        }
-
         // Get pads shape and rank. The pads tensor is expected to be 1-D
         // tensor.
         auto padsTensorType = pads.getType().cast<Torch::ValueTensorType>();
@@ -964,14 +934,48 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
         }
         ArrayRef<int64_t> padsShape = padsTensorType.getSizes();
         int64_t padsRank = padsShape.size();
-        if (padsRank != 1) {
+        if (padsRank != 1)
           return rewriter.notifyMatchFailure(binder.op,
-                                             "Expect 1-D pad tensor");
-        }
-
+                                             "expect 1-d pad tensor");
+
+        int64_t padsSize = padsShape[0];
+        if (padsSize == Torch::kUnknownSize)
+          return rewriter.notifyMatchFailure(binder.op,
+                                             "pad length is unknown");
+
+        Value constantValue;
+        if (binder.getNumOperands() >= 3) {
+          if (!binder.tensorOperandAtIndex(constantValue, 2)) {
+            auto constTy =
+                dyn_cast<Torch::BaseTensorType>(constantValue.getType());
+            if (!constTy || !constTy.hasDtype())
+              return rewriter.notifyMatchFailure(
+                  binder.op, "constant ty is unsupport type");
+
+            Type scalarTy = rewriter.getType<Torch::IntType>();
+            if (isa<FloatType>(constTy.getDtype()))
+              scalarTy = rewriter.getType<Torch::FloatType>();
+            constantValue = rewriter.create<Torch::AtenItemOp>(loc, scalarTy,
+                                                               constantValue);
+          }
+        }
+
+        if (!constantValue) {
+          auto dataTensorType = data.getType().cast<Torch::ValueTensorType>();
+          if (dataTensorType.getDtype().isa<IntegerType>())
+            constantValue = rewriter.create<Torch::ConstantIntOp>(
+                loc, rewriter.getI64IntegerAttr(0));
+          if (dataTensorType.getDtype().isa<FloatType>())
+            constantValue = rewriter.create<Torch::ConstantFloatOp>(
+                loc, rewriter.getF64FloatAttr(0.0f));
+
+          if (!constantValue)
+            return rewriter.notifyMatchFailure(
+                binder.op, "expected integer or float data tensor");
+        }
+
         // Extract all the values of 1-D pad tensor and create a list of all
         // these values as torch.pad op expects pad list.
-        int64_t padsSize = padsShape[0];
         Value constZero = rewriter.create<Torch::ConstantIntOp>(
             loc, rewriter.getI64IntegerAttr(0));
         SmallVector<Value> padsTensorValue;
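This hunk also changes how the optional constant_value operand (a 0-d tensor) becomes a scalar: it is read directly with Torch::AtenItemOp instead of being round-tripped through an f64 aten.to.dtype. In the lowered IR that is a single item op, e.g. (mirroring the updated FileCheck test below; SSA names illustrative):

  %val = torch.aten.item %arg2 : !torch.vtensor<[],f32> -> !torch.float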
@@ -982,8 +986,11 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
         for (uint32_t i = 0; i < padsSize; ++i) {
           Value index = rewriter.create<Torch::ConstantIntOp>(
               loc, rewriter.getI64IntegerAttr(i));
-          padsTensorValue.emplace_back(rewriter.create<Torch::AtenSelectIntOp>(
-              loc, padsElemType, pads, constZero, index));
+          auto select = rewriter.create<Torch::AtenSelectIntOp>(
+              loc, padsElemType, pads, constZero, index);
+          Value selectInt = rewriter.create<Torch::AtenItemOp>(
+              loc, rewriter.getType<Torch::IntType>(), select);
+          padsTensorValue.push_back(selectInt);
         }

         // The torch.pad op expects a different arrangement of padding pairs for
@@ -991,43 +998,22 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
         // tensor to satisfy torch.pad op semantics.
         SmallVector<Value> padsRearrange;
         for (uint32_t i = 0; i < padsSize / 2; i++) {
-          padsRearrange.emplace_back(padsTensorValue[(padsSize / 2) - 1 - i]);
-          padsRearrange.emplace_back(padsTensorValue[padsSize - 1 - i]);
+          padsRearrange.emplace_back(padsTensorValue[i]);
+          padsRearrange.emplace_back(padsTensorValue[(padsSize / 2) + i]);
         }

         Value padsSizeList =
             rewriter
-                .create<Torch::PrimTolistOp>(
+                .create<Torch::PrimListConstructOp>(
                     loc,
                     Torch::ListType::get(rewriter.getType<Torch::IntType>()),
                     padsRearrange)
-                .getResult(0);
+                .getResult();
         Value modeVal = rewriter.create<Torch::ConstantStrOp>(
             loc, rewriter.getStringAttr(mode));

-        // The constant value is a 0-d tensor, which needs to be converted to a
-        // float scalar as torch.pad op expects a float scalar.
-        auto constValueType =
-            constantValue.getType().cast<Torch::ValueTensorType>();
-        if (!constValueType) {
-          return rewriter.notifyMatchFailure(binder.op,
-                                             "Expect non-none constant value");
-        }
-        auto resultTensorType = Torch::ValueTensorType::get(
-            constValueType.getContext(), emptyShape, rewriter.getF64Type());
-        Value none = rewriter.create<Torch::ConstantNoneOp>(loc);
-        Value cstFalse = rewriter.create<Torch::ConstantBoolOp>(loc, false);
-        Value constFloatValue = rewriter.create<Torch::AtenToDtypeOp>(
-            loc, resultTensorType, constantValue,
-            Torch::getDtypeIntValueForType(rewriter, loc,
-                                           resultTensorType.getOptionalDtype()),
-            /*non_blocking=*/cstFalse, /*copy=*/cstFalse,
-            /*memory_format=*/none);
-        Value constScalar = rewriter.create<Torch::AtenItemOp>(
-            loc, rewriter.getType<Torch::FloatType>(), constFloatValue);
-
         rewriter.replaceOpWithNewOp<Torch::AtenPadOp>(
-            binder.op, resultType, data, padsSizeList, modeVal, constScalar);
+            binder.op, resultType, data, padsSizeList, modeVal, constantValue);
         return success();
       });
   patterns.onOp("Pow", 1,
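Taken together, the updated handler now pulls each pad amount out of the 1-D pads tensor with aten.select.int / aten.item, rebuilds the list with prim.ListConstruct, and passes the (possibly non-constant) values straight to torch.aten.pad. A trimmed sketch of the resulting IR, assuming a rank-2 input and illustrative SSA names (the full sequence is checked in the FileCheck test further down):

  %zero = torch.constant.int 0
  %sel0 = torch.aten.select.int %pads, %zero, %zero : !torch.vtensor<[4],si64>, !torch.int, !torch.int -> !torch.vtensor<[],si64>
  %item0 = torch.aten.item %sel0 : !torch.vtensor<[],si64> -> !torch.int
  // ... the same select/item pattern repeats for the remaining pad entries ...
  %list = torch.prim.ListConstruct %item0, %item2, %item1, %item3 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
  %mode = torch.constant.str "constant"
  %padded = torch.aten.pad %data, %list, %mode, %fill : !torch.vtensor<[3,4],f32>, !torch.list<int>, !torch.str, !torch.float -> !torch.vtensor<[5,4],f32>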
@@ -45,36 +45,69 @@ public:
     auto type = self.getType().cast<RankedTensorType>();
     int64_t rank = type.getRank();

-    // Pattern match against the op's original operands, because otherwise we
-    // will get the lowered version of the operands which is harder to pattern
-    // match.
-    SmallVector<int64_t> padInts;
-    if (!matchPattern(op.getPad(), m_TorchListOfConstantInts(padInts)))
-      return rewriter.notifyMatchFailure(
-          op, "only support constant int pad ranges");
-    uint64_t padRank = padInts.size() / 2;
-    if (padRank * 2 != padInts.size())
+    auto primList = op.getPad().getDefiningOp<Torch::PrimListConstructOp>();
+    if (!primList) {
+      return rewriter.notifyMatchFailure(op, "unable to get pad values");
+    }
+
+    SmallVector<Value> padVals(primList.getOperands());
+
+    uint64_t padRank = padVals.size() / 2;
+    if (padRank * 2 != padVals.size())
       return rewriter.notifyMatchFailure(op, "pad range size is not even");
     if (rank < 0 || padRank > (uint64_t)rank)
       return rewriter.notifyMatchFailure(op, "padding exceeds tensor rank");

     // Initialize low/high paddings with the dims that should not be padded.
-    SmallVector<int64_t, 4> lowPadding(/*Size=*/rank - padRank, /*Value=*/0);
-    SmallVector<int64_t, 4> highPadding(/*Size=*/rank - padRank, /*Value=*/0);
+    int64_t noPad = rank - padRank;
+    Attribute zero = rewriter.getIndexAttr(0);
+    SmallVector<int64_t> staticLow(noPad, 0);
+    SmallVector<int64_t> staticHigh(noPad, 0);
+    SmallVector<OpFoldResult> lowPad(noPad, zero);
+    SmallVector<OpFoldResult> highPad(noPad, zero);
+
+    auto tc = getTypeConverter();

     // Add the requested padding - note op.pad() is highest dim first ordered
     // pairs of low,high.
     for (uint64_t i = padRank; i > 0; --i) {
-      lowPadding.push_back(padInts[i * 2 - 2]);
-      highPadding.push_back(padInts[i * 2 - 1]);
+      int64_t lowi, highi;
+      Value lowv = padVals[i * 2 - 2];
+      Value highv = padVals[i * 2 - 1];
+      if (!matchPattern(lowv, m_TorchConstantInt(&lowi))) {
+        Type cty = tc->convertType(lowv.getType());
+        lowv = tc->materializeTargetConversion(rewriter, loc, cty, lowv);
+        lowv = rewriter.create<arith::IndexCastOp>(loc, rewriter.getIndexType(),
+                                                   lowv);
+        lowPad.push_back(lowv);
+        staticLow.push_back(ShapedType::kDynamic);
+      } else {
+        lowPad.push_back(rewriter.getIndexAttr(lowi));
+        staticLow.push_back(lowi);
+      }
+
+      if (!matchPattern(highv, m_TorchConstantInt(&highi))) {
+        Type cty = tc->convertType(highv.getType());
+        highv = tc->materializeTargetConversion(rewriter, loc, cty, highv);
+        highv = rewriter.create<arith::IndexCastOp>(
+            loc, rewriter.getIndexType(), highv);
+        highPad.push_back(highv);
+        staticHigh.push_back(ShapedType::kDynamic);
+      } else {
+        highPad.push_back(rewriter.getIndexAttr(highi));
+        staticHigh.push_back(highi);
+      }
     }

     Type newResultType = getTypeConverter()->convertType(op.getType());
     Type elementType = newResultType.cast<RankedTensorType>().getElementType();
     Value castedValue =
         convertScalarToDtype(rewriter, loc, adaptor.getValue(), elementType);
-    Value paddedInput = torch_to_linalg::getPaddedTensor(
-        op, rewriter, self, lowPadding, highPadding, castedValue);
+
+    Type padType = tensor::PadOp::inferResultType(
+        self.getType().cast<RankedTensorType>(), staticLow, staticHigh);
+    Value paddedInput = rewriter.create<tensor::PadOp>(
+        loc, padType, self, lowPad, highPad, castedValue);
     rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, paddedInput);
     return success();
   }
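On the TorchToLinalg side, pad amounts that are not compile-time constants are materialized as index values and handed to tensor.pad, whose result type then carries dynamic dimensions; the final tensor.cast reconciles that with the converted result type. A minimal sketch of the emitted form, assuming one dynamically padded trailing dimension (%lo, %hi, %fill and the shapes are illustrative):

  %padded = tensor.pad %input low[0, %lo] high[0, %hi] {
  ^bb0(%i: index, %j: index):
    tensor.yield %fill : f32
  } : tensor<3x4xf32> to tensor<3x?xf32>
  %result = tensor.cast %padded : tensor<3x?xf32> to tensor<?x?xf32>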
@@ -1955,12 +1955,6 @@ ONNX_XFAIL_SET = {
     "OneHotModule_basic",

     # Failure - onnx_lowering: onnx.Pad
-    "ConstantPad2dStaticModule_basic",
-    "ConstantPadNdModule_basic",
-    "ConstantPadNdPartialStaticModule_basic",
-    "ConstantPadNdStaticModule_basic",
-    "PadModule_basic",
-    "PadWithNoneValModule_basic",
     "ReflectionPad1dModule2dInput_Right",
     "ReflectionPad1dModule2dInput_basic",
     "ReflectionPad1dModule3dInput_Left",
@@ -109,5 +109,3 @@ class ReflectionPad2dModuleRight(torch.nn.Module):
 @register_test_case(module_factory=lambda: ReflectionPad2dModuleRight())
 def ReflectionPad2dModule_Right(module, tu: TestUtils):
     module.forward(tu.rand(2, 3, 20, 20))
-
-# ==============================================================================
@@ -447,23 +447,23 @@ func.func @test_less_or_equal(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.

 // CHECK-LABEL: func.func @test_pad
 func.func @test_pad(%arg0: !torch.vtensor<[3,4],f32>, %arg1: !torch.vtensor<[4], si64>, %arg2: !torch.vtensor<[], f32>) -> !torch.vtensor<[5,4],f32> attributes {torch.onnx_meta.opset_version = 19 : si64} {
+  // CHECK: %[[VAL:.+]] = torch.aten.item %arg2 : !torch.vtensor<[],f32> -> !torch.float
   // CHECK: %[[INT0:.+]] = torch.constant.int 0
   // CHECK: %[[INT0_0:.+]] = torch.constant.int 0
   // CHECK: %[[SELECT_0:.+]] = torch.aten.select.int %arg1, %[[INT0]], %[[INT0_0]] : !torch.vtensor<[4],si64>, !torch.int, !torch.int -> !torch.vtensor<[],si64>
+  // CHECK: %[[ITEM_0:.+]] = torch.aten.item %[[SELECT_0]] : !torch.vtensor<[],si64> -> !torch.int
   // CHECK: %[[INT1:.+]] = torch.constant.int 1
   // CHECK: %[[SELECT_1:.+]] = torch.aten.select.int %arg1, %[[INT0]], %[[INT1]] : !torch.vtensor<[4],si64>, !torch.int, !torch.int -> !torch.vtensor<[],si64>
+  // CHECK: %[[ITEM_1:.+]] = torch.aten.item %[[SELECT_1]] : !torch.vtensor<[],si64> -> !torch.int
   // CHECK: %[[INT2:.+]] = torch.constant.int 2
   // CHECK: %[[SELECT_2:.+]] = torch.aten.select.int %arg1, %[[INT0]], %[[INT2]] : !torch.vtensor<[4],si64>, !torch.int, !torch.int -> !torch.vtensor<[],si64>
+  // CHECK: %[[ITEM_2:.+]] = torch.aten.item %[[SELECT_2]] : !torch.vtensor<[],si64> -> !torch.int
   // CHECK: %[[INT3:.+]] = torch.constant.int 3
   // CHECK: %[[SELECT_3:.+]] = torch.aten.select.int %arg1, %[[INT0]], %[[INT3]] : !torch.vtensor<[4],si64>, !torch.int, !torch.int -> !torch.vtensor<[],si64>
-  // CHECK: %[[LIST:.+]] = torch.prim.tolist(%[[SELECT_1]], %[[SELECT_3]], %[[SELECT_0]], %[[SELECT_2]]) : !torch.vtensor<[],si64>, !torch.vtensor<[],si64>, !torch.vtensor<[],si64>, !torch.vtensor<[],si64> -> !torch.list<int>
+  // CHECK: %[[ITEM_3:.+]] = torch.aten.item %[[SELECT_3]] : !torch.vtensor<[],si64> -> !torch.int
+  // CHECK: %[[LIST:.+]] = torch.prim.ListConstruct %[[ITEM_0]], %[[ITEM_2]], %[[ITEM_1]], %[[ITEM_3]] : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
   // CHECK: %[[STR:.+]] = torch.constant.str "constant"
-  // CHECK: %[[NONE:.+]] = torch.constant.none
-  // CHECK: %[[FALSE:.+]] = torch.constant.bool false
-  // CHECK: %[[INT7:.+]] = torch.constant.int 7
-  // CHECK: %[[CONVERT:.+]] = torch.aten.to.dtype %arg2, %[[INT7]], %[[FALSE]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[],f64>
-  // CHECK: %[[ITEM:.+]] = torch.aten.item %[[CONVERT]] : !torch.vtensor<[],f64> -> !torch.float
-  // CHECK: %[[PAD:.+]] = torch.aten.pad %arg0, %[[LIST]], %[[STR]], %[[ITEM]] : !torch.vtensor<[3,4],f32>, !torch.list<int>, !torch.str, !torch.float -> !torch.vtensor<[5,4],f32>
+  // CHECK: %[[PAD:.+]] = torch.aten.pad %arg0, %[[LIST]], %[[STR]], %[[VAL]] : !torch.vtensor<[3,4],f32>, !torch.list<int>, !torch.str, !torch.float -> !torch.vtensor<[5,4],f32>
   // CHECK: return %[[PAD]] : !torch.vtensor<[5,4],f32>
   %0 = torch.operator "onnx.Pad"(%arg0, %arg1, %arg2) {torch.onnx.mode = "constant"} : (!torch.vtensor<[3,4],f32>, !torch.vtensor<[4], si64>, !torch.vtensor<[], f32>) -> !torch.vtensor<[5,4],f32>
   return %0 : !torch.vtensor<[5,4],f32>
@@ -474,13 +474,9 @@ func.func @test_pad(%arg0: !torch.vtensor<[3,4],f32>, %arg1: !torch.vtensor<[4],
 // CHECK-LABEL: @test_pad_optional_constant
 // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[3,4],f32>
 // CHECK-SAME: %[[ARG1:.*]]: !torch.vtensor<[4],si64>
+// CHECK: %[[VAL:.+]] = torch.constant.float 0
 // CHECK: %[[CONST_STR:.*]] = torch.constant.str "constant"
-// CHECK: %[[NONE:.*]] = torch.constant.none
-// CHECK: %[[FALSE:.*]] = torch.constant.bool false
-// CHECK: %[[SEVEN:.*]] = torch.constant.int 7
-// CHECK: %[[DTYPE:.*]] = torch.aten.to.dtype %0, %[[SEVEN]], %[[FALSE]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[3,4],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[],f64>
-// CHECK: %[[ITEM:.*]] = torch.aten.item %[[DTYPE]] : !torch.vtensor<[],f64> -> !torch.float
-// CHECK: torch.aten.pad %[[ARG0]], %{{.*}}, %[[CONST_STR]], %[[ITEM]] : !torch.vtensor<[3,4],f32>, !torch.list<int>, !torch.str, !torch.float -> !torch.vtensor<[5,4],f32>
+// CHECK: torch.aten.pad %[[ARG0]], %{{.*}}, %[[CONST_STR]], %[[VAL]] : !torch.vtensor<[3,4],f32>, !torch.list<int>, !torch.str, !torch.float -> !torch.vtensor<[5,4],f32>

 func.func @test_pad_optional_constant(%arg0: !torch.vtensor<[3,4],f32>, %arg1: !torch.vtensor<[4], si64>) -> !torch.vtensor<[5,4],f32> attributes {torch.onnx_meta.opset_version = 19 : si64} {
   %0 = torch.operator "onnx.Pad"(%arg0, %arg1) {torch.onnx.mode = "constant"} : (!torch.vtensor<[3,4],f32>, !torch.vtensor<[4], si64>) -> !torch.vtensor<[5,4],f32>