mirror of https://github.com/llvm/torch-mlir
[onnx] Fix constant pad for dynamic shape (#2989)
The current padding operation was not functional for dynamic shapes. Updated and enabled tests so that onnx.pad tests pass. Work TBD for reflection padding.pull/2995/head
parent
7b18646def
commit
1964208d19
|
@ -908,7 +908,7 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
|
||||||
return success();
|
return success();
|
||||||
});
|
});
|
||||||
patterns.onOp(
|
patterns.onOp(
|
||||||
"Pad", 19, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
|
"Pad", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
|
||||||
Torch::ValueTensorType resultType;
|
Torch::ValueTensorType resultType;
|
||||||
Value data, pads, axes;
|
Value data, pads, axes;
|
||||||
std::string mode;
|
std::string mode;
|
||||||
|
@ -925,36 +925,6 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
|
||||||
return failure();
|
return failure();
|
||||||
Location loc = binder.getLoc();
|
Location loc = binder.getLoc();
|
||||||
|
|
||||||
Value constantValue;
|
|
||||||
if (binder.getNumOperands() >= 3) {
|
|
||||||
if (binder.tensorOperandAtIndex(constantValue, 2)) {
|
|
||||||
llvm::errs() << "failed to bind to index 2\n";
|
|
||||||
return failure();
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
auto dataTensorType = data.getType().cast<Torch::ValueTensorType>();
|
|
||||||
|
|
||||||
auto maybeZeroAttr = [&]() -> std::optional<Attribute> {
|
|
||||||
if (dataTensorType.getDtype().isa<IntegerType>()) {
|
|
||||||
return rewriter.getI64IntegerAttr(0);
|
|
||||||
}
|
|
||||||
if (dataTensorType.getDtype().isa<FloatType>()) {
|
|
||||||
return rewriter.getFloatAttr(dataTensorType.getDtype(), 0.0f);
|
|
||||||
}
|
|
||||||
return std::nullopt;
|
|
||||||
}();
|
|
||||||
|
|
||||||
if (!maybeZeroAttr) {
|
|
||||||
return rewriter.notifyMatchFailure(
|
|
||||||
binder.op, "expected integer or float data tensor");
|
|
||||||
}
|
|
||||||
|
|
||||||
auto shapedType = dataTensorType.toBuiltinTensor();
|
|
||||||
auto splat = SplatElementsAttr::get(shapedType, *maybeZeroAttr);
|
|
||||||
constantValue = rewriter.create<Torch::ValueTensorLiteralOp>(
|
|
||||||
loc, dataTensorType, splat);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Get pads shape and rank. The pads tensor is expected to be 1-D
|
// Get pads shape and rank. The pads tensor is expected to be 1-D
|
||||||
// tensor.
|
// tensor.
|
||||||
auto padsTensorType = pads.getType().cast<Torch::ValueTensorType>();
|
auto padsTensorType = pads.getType().cast<Torch::ValueTensorType>();
|
||||||
|
@ -964,14 +934,48 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
|
||||||
}
|
}
|
||||||
ArrayRef<int64_t> padsShape = padsTensorType.getSizes();
|
ArrayRef<int64_t> padsShape = padsTensorType.getSizes();
|
||||||
int64_t padsRank = padsShape.size();
|
int64_t padsRank = padsShape.size();
|
||||||
if (padsRank != 1) {
|
if (padsRank != 1)
|
||||||
return rewriter.notifyMatchFailure(binder.op,
|
return rewriter.notifyMatchFailure(binder.op,
|
||||||
"Expect 1-D pad tensor");
|
"expect 1-d pad tensor");
|
||||||
|
|
||||||
|
int64_t padsSize = padsShape[0];
|
||||||
|
if (padsSize == Torch::kUnknownSize)
|
||||||
|
return rewriter.notifyMatchFailure(binder.op,
|
||||||
|
"pad length is unknown");
|
||||||
|
|
||||||
|
Value constantValue;
|
||||||
|
if (binder.getNumOperands() >= 3) {
|
||||||
|
if (!binder.tensorOperandAtIndex(constantValue, 2)) {
|
||||||
|
auto constTy =
|
||||||
|
dyn_cast<Torch::BaseTensorType>(constantValue.getType());
|
||||||
|
if (!constTy || !constTy.hasDtype())
|
||||||
|
return rewriter.notifyMatchFailure(
|
||||||
|
binder.op, "constant ty is unsupport type");
|
||||||
|
|
||||||
|
Type scalarTy = rewriter.getType<Torch::IntType>();
|
||||||
|
if (isa<FloatType>(constTy.getDtype()))
|
||||||
|
scalarTy = rewriter.getType<Torch::FloatType>();
|
||||||
|
constantValue = rewriter.create<Torch::AtenItemOp>(loc, scalarTy,
|
||||||
|
constantValue);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!constantValue) {
|
||||||
|
auto dataTensorType = data.getType().cast<Torch::ValueTensorType>();
|
||||||
|
if (dataTensorType.getDtype().isa<IntegerType>())
|
||||||
|
constantValue = rewriter.create<Torch::ConstantIntOp>(
|
||||||
|
loc, rewriter.getI64IntegerAttr(0));
|
||||||
|
if (dataTensorType.getDtype().isa<FloatType>())
|
||||||
|
constantValue = rewriter.create<Torch::ConstantFloatOp>(
|
||||||
|
loc, rewriter.getF64FloatAttr(0.0f));
|
||||||
|
|
||||||
|
if (!constantValue)
|
||||||
|
return rewriter.notifyMatchFailure(
|
||||||
|
binder.op, "expected integer or float data tensor");
|
||||||
}
|
}
|
||||||
|
|
||||||
// Extract all the values of 1-D pad tensor and create a list of all
|
// Extract all the values of 1-D pad tensor and create a list of all
|
||||||
// these values as torch.pad op expects pad list.
|
// these values as torch.pad op expects pad list.
|
||||||
int64_t padsSize = padsShape[0];
|
|
||||||
Value constZero = rewriter.create<Torch::ConstantIntOp>(
|
Value constZero = rewriter.create<Torch::ConstantIntOp>(
|
||||||
loc, rewriter.getI64IntegerAttr(0));
|
loc, rewriter.getI64IntegerAttr(0));
|
||||||
SmallVector<Value> padsTensorValue;
|
SmallVector<Value> padsTensorValue;
|
||||||
|
@ -982,8 +986,11 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
|
||||||
for (uint32_t i = 0; i < padsSize; ++i) {
|
for (uint32_t i = 0; i < padsSize; ++i) {
|
||||||
Value index = rewriter.create<Torch::ConstantIntOp>(
|
Value index = rewriter.create<Torch::ConstantIntOp>(
|
||||||
loc, rewriter.getI64IntegerAttr(i));
|
loc, rewriter.getI64IntegerAttr(i));
|
||||||
padsTensorValue.emplace_back(rewriter.create<Torch::AtenSelectIntOp>(
|
auto select = rewriter.create<Torch::AtenSelectIntOp>(
|
||||||
loc, padsElemType, pads, constZero, index));
|
loc, padsElemType, pads, constZero, index);
|
||||||
|
Value selectInt = rewriter.create<Torch::AtenItemOp>(
|
||||||
|
loc, rewriter.getType<Torch::IntType>(), select);
|
||||||
|
padsTensorValue.push_back(selectInt);
|
||||||
}
|
}
|
||||||
|
|
||||||
// The torch.pad op expects a different arrangement of padding pairs for
|
// The torch.pad op expects a different arrangement of padding pairs for
|
||||||
|
@ -991,43 +998,22 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
|
||||||
// tensor to satisfy torch.pad op semantics.
|
// tensor to satisfy torch.pad op semantics.
|
||||||
SmallVector<Value> padsRearrange;
|
SmallVector<Value> padsRearrange;
|
||||||
for (uint32_t i = 0; i < padsSize / 2; i++) {
|
for (uint32_t i = 0; i < padsSize / 2; i++) {
|
||||||
padsRearrange.emplace_back(padsTensorValue[(padsSize / 2) - 1 - i]);
|
padsRearrange.emplace_back(padsTensorValue[i]);
|
||||||
padsRearrange.emplace_back(padsTensorValue[padsSize - 1 - i]);
|
padsRearrange.emplace_back(padsTensorValue[(padsSize / 2) + i]);
|
||||||
}
|
}
|
||||||
|
|
||||||
Value padsSizeList =
|
Value padsSizeList =
|
||||||
rewriter
|
rewriter
|
||||||
.create<Torch::PrimTolistOp>(
|
.create<Torch::PrimListConstructOp>(
|
||||||
loc,
|
loc,
|
||||||
Torch::ListType::get(rewriter.getType<Torch::IntType>()),
|
Torch::ListType::get(rewriter.getType<Torch::IntType>()),
|
||||||
padsRearrange)
|
padsRearrange)
|
||||||
.getResult(0);
|
.getResult();
|
||||||
Value modeVal = rewriter.create<Torch::ConstantStrOp>(
|
Value modeVal = rewriter.create<Torch::ConstantStrOp>(
|
||||||
loc, rewriter.getStringAttr(mode));
|
loc, rewriter.getStringAttr(mode));
|
||||||
|
|
||||||
// The constant value is a 0-d tensor, which needs to be converted to a
|
|
||||||
// float scalar as torch.pad op expects a float scalar.
|
|
||||||
auto constValueType =
|
|
||||||
constantValue.getType().cast<Torch::ValueTensorType>();
|
|
||||||
if (!constValueType) {
|
|
||||||
return rewriter.notifyMatchFailure(binder.op,
|
|
||||||
"Expect non-none constant value");
|
|
||||||
}
|
|
||||||
auto resultTensorType = Torch::ValueTensorType::get(
|
|
||||||
constValueType.getContext(), emptyShape, rewriter.getF64Type());
|
|
||||||
Value none = rewriter.create<Torch::ConstantNoneOp>(loc);
|
|
||||||
Value cstFalse = rewriter.create<Torch::ConstantBoolOp>(loc, false);
|
|
||||||
Value constFloatValue = rewriter.create<Torch::AtenToDtypeOp>(
|
|
||||||
loc, resultTensorType, constantValue,
|
|
||||||
Torch::getDtypeIntValueForType(rewriter, loc,
|
|
||||||
resultTensorType.getOptionalDtype()),
|
|
||||||
/*non_blocking=*/cstFalse, /*copy=*/cstFalse,
|
|
||||||
/*memory_format=*/none);
|
|
||||||
Value constScalar = rewriter.create<Torch::AtenItemOp>(
|
|
||||||
loc, rewriter.getType<Torch::FloatType>(), constFloatValue);
|
|
||||||
|
|
||||||
rewriter.replaceOpWithNewOp<Torch::AtenPadOp>(
|
rewriter.replaceOpWithNewOp<Torch::AtenPadOp>(
|
||||||
binder.op, resultType, data, padsSizeList, modeVal, constScalar);
|
binder.op, resultType, data, padsSizeList, modeVal, constantValue);
|
||||||
return success();
|
return success();
|
||||||
});
|
});
|
||||||
patterns.onOp("Pow", 1,
|
patterns.onOp("Pow", 1,
|
||||||
|
|
|
@ -45,36 +45,69 @@ public:
|
||||||
auto type = self.getType().cast<RankedTensorType>();
|
auto type = self.getType().cast<RankedTensorType>();
|
||||||
int64_t rank = type.getRank();
|
int64_t rank = type.getRank();
|
||||||
|
|
||||||
// Pattern match against the op's original operands, because otherwise we
|
auto primList = op.getPad().getDefiningOp<Torch::PrimListConstructOp>();
|
||||||
// will get the lowered version of the operands which is harder to pattern
|
if (!primList) {
|
||||||
// match.
|
return rewriter.notifyMatchFailure(op, "unable to get pad values");
|
||||||
SmallVector<int64_t> padInts;
|
}
|
||||||
if (!matchPattern(op.getPad(), m_TorchListOfConstantInts(padInts)))
|
|
||||||
return rewriter.notifyMatchFailure(
|
SmallVector<Value> padVals(primList.getOperands());
|
||||||
op, "only support constant int pad ranges");
|
|
||||||
uint64_t padRank = padInts.size() / 2;
|
uint64_t padRank = padVals.size() / 2;
|
||||||
if (padRank * 2 != padInts.size())
|
if (padRank * 2 != padVals.size())
|
||||||
return rewriter.notifyMatchFailure(op, "pad range size is not even");
|
return rewriter.notifyMatchFailure(op, "pad range size is not even");
|
||||||
if (rank < 0 || padRank > (uint64_t)rank)
|
if (rank < 0 || padRank > (uint64_t)rank)
|
||||||
return rewriter.notifyMatchFailure(op, "padding exceeds tensor rank");
|
return rewriter.notifyMatchFailure(op, "padding exceeds tensor rank");
|
||||||
|
|
||||||
// Initialize low/high paddings with the dims that should not be padded.
|
// Initialize low/high paddings with the dims that should not be padded.
|
||||||
SmallVector<int64_t, 4> lowPadding(/*Size=*/rank - padRank, /*Value=*/0);
|
int64_t noPad = rank - padRank;
|
||||||
SmallVector<int64_t, 4> highPadding(/*Size=*/rank - padRank, /*Value=*/0);
|
Attribute zero = rewriter.getIndexAttr(0);
|
||||||
|
SmallVector<int64_t> staticLow(noPad, 0);
|
||||||
|
SmallVector<int64_t> staticHigh(noPad, 0);
|
||||||
|
SmallVector<OpFoldResult> lowPad(noPad, zero);
|
||||||
|
SmallVector<OpFoldResult> highPad(noPad, zero);
|
||||||
|
|
||||||
|
auto tc = getTypeConverter();
|
||||||
|
|
||||||
// Add the requested padding - note op.pad() is highest dim first ordered
|
// Add the requested padding - note op.pad() is highest dim first ordered
|
||||||
// pairs of low,high.
|
// pairs of low,high.
|
||||||
for (uint64_t i = padRank; i > 0; --i) {
|
for (uint64_t i = padRank; i > 0; --i) {
|
||||||
lowPadding.push_back(padInts[i * 2 - 2]);
|
int64_t lowi, highi;
|
||||||
highPadding.push_back(padInts[i * 2 - 1]);
|
Value lowv = padVals[i * 2 - 2];
|
||||||
|
Value highv = padVals[i * 2 - 1];
|
||||||
|
if (!matchPattern(lowv, m_TorchConstantInt(&lowi))) {
|
||||||
|
Type cty = tc->convertType(lowv.getType());
|
||||||
|
lowv = tc->materializeTargetConversion(rewriter, loc, cty, lowv);
|
||||||
|
lowv = rewriter.create<arith::IndexCastOp>(loc, rewriter.getIndexType(),
|
||||||
|
lowv);
|
||||||
|
lowPad.push_back(lowv);
|
||||||
|
staticLow.push_back(ShapedType::kDynamic);
|
||||||
|
} else {
|
||||||
|
lowPad.push_back(rewriter.getIndexAttr(lowi));
|
||||||
|
staticLow.push_back(lowi);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!matchPattern(highv, m_TorchConstantInt(&highi))) {
|
||||||
|
Type cty = tc->convertType(highv.getType());
|
||||||
|
highv = tc->materializeTargetConversion(rewriter, loc, cty, highv);
|
||||||
|
highv = rewriter.create<arith::IndexCastOp>(
|
||||||
|
loc, rewriter.getIndexType(), highv);
|
||||||
|
highPad.push_back(highv);
|
||||||
|
staticHigh.push_back(ShapedType::kDynamic);
|
||||||
|
} else {
|
||||||
|
highPad.push_back(rewriter.getIndexAttr(highi));
|
||||||
|
staticHigh.push_back(highi);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
Type newResultType = getTypeConverter()->convertType(op.getType());
|
Type newResultType = getTypeConverter()->convertType(op.getType());
|
||||||
Type elementType = newResultType.cast<RankedTensorType>().getElementType();
|
Type elementType = newResultType.cast<RankedTensorType>().getElementType();
|
||||||
Value castedValue =
|
Value castedValue =
|
||||||
convertScalarToDtype(rewriter, loc, adaptor.getValue(), elementType);
|
convertScalarToDtype(rewriter, loc, adaptor.getValue(), elementType);
|
||||||
Value paddedInput = torch_to_linalg::getPaddedTensor(
|
|
||||||
op, rewriter, self, lowPadding, highPadding, castedValue);
|
|
||||||
|
|
||||||
|
Type padType = tensor::PadOp::inferResultType(
|
||||||
|
self.getType().cast<RankedTensorType>(), staticLow, staticHigh);
|
||||||
|
Value paddedInput = rewriter.create<tensor::PadOp>(
|
||||||
|
loc, padType, self, lowPad, highPad, castedValue);
|
||||||
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, paddedInput);
|
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, paddedInput);
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
|
@ -1955,12 +1955,6 @@ ONNX_XFAIL_SET = {
|
||||||
"OneHotModule_basic",
|
"OneHotModule_basic",
|
||||||
|
|
||||||
# Failure - onnx_lowering: onnx.Pad
|
# Failure - onnx_lowering: onnx.Pad
|
||||||
"ConstantPad2dStaticModule_basic",
|
|
||||||
"ConstantPadNdModule_basic",
|
|
||||||
"ConstantPadNdPartialStaticModule_basic",
|
|
||||||
"ConstantPadNdStaticModule_basic",
|
|
||||||
"PadModule_basic",
|
|
||||||
"PadWithNoneValModule_basic",
|
|
||||||
"ReflectionPad1dModule2dInput_Right",
|
"ReflectionPad1dModule2dInput_Right",
|
||||||
"ReflectionPad1dModule2dInput_basic",
|
"ReflectionPad1dModule2dInput_basic",
|
||||||
"ReflectionPad1dModule3dInput_Left",
|
"ReflectionPad1dModule3dInput_Left",
|
||||||
|
|
|
@ -109,5 +109,3 @@ class ReflectionPad2dModuleRight(torch.nn.Module):
|
||||||
@register_test_case(module_factory=lambda: ReflectionPad2dModuleRight())
|
@register_test_case(module_factory=lambda: ReflectionPad2dModuleRight())
|
||||||
def ReflectionPad2dModule_Right(module, tu: TestUtils):
|
def ReflectionPad2dModule_Right(module, tu: TestUtils):
|
||||||
module.forward(tu.rand(2, 3, 20, 20))
|
module.forward(tu.rand(2, 3, 20, 20))
|
||||||
|
|
||||||
# ==============================================================================
|
|
||||||
|
|
|
@ -447,23 +447,23 @@ func.func @test_less_or_equal(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.
|
||||||
|
|
||||||
// CHECK-LABEL: func.func @test_pad
|
// CHECK-LABEL: func.func @test_pad
|
||||||
func.func @test_pad(%arg0: !torch.vtensor<[3,4],f32>, %arg1: !torch.vtensor<[4], si64>, %arg2: !torch.vtensor<[], f32>) -> !torch.vtensor<[5,4],f32> attributes {torch.onnx_meta.opset_version = 19 : si64} {
|
func.func @test_pad(%arg0: !torch.vtensor<[3,4],f32>, %arg1: !torch.vtensor<[4], si64>, %arg2: !torch.vtensor<[], f32>) -> !torch.vtensor<[5,4],f32> attributes {torch.onnx_meta.opset_version = 19 : si64} {
|
||||||
|
// CHECK: %[[VAL:.+]] = torch.aten.item %arg2 : !torch.vtensor<[],f32> -> !torch.float
|
||||||
// CHECK: %[[INT0:.+]] = torch.constant.int 0
|
// CHECK: %[[INT0:.+]] = torch.constant.int 0
|
||||||
// CHECK: %[[INT0_0:.+]] = torch.constant.int 0
|
// CHECK: %[[INT0_0:.+]] = torch.constant.int 0
|
||||||
// CHECK: %[[SELECT_0:.+]] = torch.aten.select.int %arg1, %[[INT0]], %[[INT0_0]] : !torch.vtensor<[4],si64>, !torch.int, !torch.int -> !torch.vtensor<[],si64>
|
// CHECK: %[[SELECT_0:.+]] = torch.aten.select.int %arg1, %[[INT0]], %[[INT0_0]] : !torch.vtensor<[4],si64>, !torch.int, !torch.int -> !torch.vtensor<[],si64>
|
||||||
|
// CHECK: %[[ITEM_0:.+]] = torch.aten.item %[[SELECT_0]] : !torch.vtensor<[],si64> -> !torch.int
|
||||||
// CHECK: %[[INT1:.+]] = torch.constant.int 1
|
// CHECK: %[[INT1:.+]] = torch.constant.int 1
|
||||||
// CHECK: %[[SELECT_1:.+]] = torch.aten.select.int %arg1, %[[INT0]], %[[INT1]] : !torch.vtensor<[4],si64>, !torch.int, !torch.int -> !torch.vtensor<[],si64>
|
// CHECK: %[[SELECT_1:.+]] = torch.aten.select.int %arg1, %[[INT0]], %[[INT1]] : !torch.vtensor<[4],si64>, !torch.int, !torch.int -> !torch.vtensor<[],si64>
|
||||||
|
// CHECK: %[[ITEM_1:.+]] = torch.aten.item %[[SELECT_1]] : !torch.vtensor<[],si64> -> !torch.int
|
||||||
// CHECK: %[[INT2:.+]] = torch.constant.int 2
|
// CHECK: %[[INT2:.+]] = torch.constant.int 2
|
||||||
// CHECK: %[[SELECT_2:.+]] = torch.aten.select.int %arg1, %[[INT0]], %[[INT2]] : !torch.vtensor<[4],si64>, !torch.int, !torch.int -> !torch.vtensor<[],si64>
|
// CHECK: %[[SELECT_2:.+]] = torch.aten.select.int %arg1, %[[INT0]], %[[INT2]] : !torch.vtensor<[4],si64>, !torch.int, !torch.int -> !torch.vtensor<[],si64>
|
||||||
|
// CHECK: %[[ITEM_2:.+]] = torch.aten.item %[[SELECT_2]] : !torch.vtensor<[],si64> -> !torch.int
|
||||||
// CHECK: %[[INT3:.+]] = torch.constant.int 3
|
// CHECK: %[[INT3:.+]] = torch.constant.int 3
|
||||||
// CHECK: %[[SELECT_3:.+]] = torch.aten.select.int %arg1, %[[INT0]], %[[INT3]] : !torch.vtensor<[4],si64>, !torch.int, !torch.int -> !torch.vtensor<[],si64>
|
// CHECK: %[[SELECT_3:.+]] = torch.aten.select.int %arg1, %[[INT0]], %[[INT3]] : !torch.vtensor<[4],si64>, !torch.int, !torch.int -> !torch.vtensor<[],si64>
|
||||||
// CHECK: %[[LIST:.+]] = torch.prim.tolist(%[[SELECT_1]], %[[SELECT_3]], %[[SELECT_0]], %[[SELECT_2]]) : !torch.vtensor<[],si64>, !torch.vtensor<[],si64>, !torch.vtensor<[],si64>, !torch.vtensor<[],si64> -> !torch.list<int>
|
// CHECK: %[[ITEM_3:.+]] = torch.aten.item %[[SELECT_3]] : !torch.vtensor<[],si64> -> !torch.int
|
||||||
|
// CHECK: %[[LIST:.+]] = torch.prim.ListConstruct %[[ITEM_0]], %[[ITEM_2]], %[[ITEM_1]], %[[ITEM_3]] : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
|
||||||
// CHECK: %[[STR:.+]] = torch.constant.str "constant"
|
// CHECK: %[[STR:.+]] = torch.constant.str "constant"
|
||||||
// CHECK: %[[NONE:.+]] = torch.constant.none
|
// CHECK: %[[PAD:.+]] = torch.aten.pad %arg0, %[[LIST]], %[[STR]], %[[VAL]] : !torch.vtensor<[3,4],f32>, !torch.list<int>, !torch.str, !torch.float -> !torch.vtensor<[5,4],f32>
|
||||||
// CHECK: %[[FALSE:.+]] = torch.constant.bool false
|
|
||||||
// CHECK: %[[INT7:.+]] = torch.constant.int 7
|
|
||||||
// CHECK: %[[CONVERT:.+]] = torch.aten.to.dtype %arg2, %[[INT7]], %[[FALSE]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[],f64>
|
|
||||||
// CHECK: %[[ITEM:.+]] = torch.aten.item %[[CONVERT]] : !torch.vtensor<[],f64> -> !torch.float
|
|
||||||
// CHECK: %[[PAD:.+]] = torch.aten.pad %arg0, %[[LIST]], %[[STR]], %[[ITEM]] : !torch.vtensor<[3,4],f32>, !torch.list<int>, !torch.str, !torch.float -> !torch.vtensor<[5,4],f32>
|
|
||||||
// CHECK: return %[[PAD]] : !torch.vtensor<[5,4],f32>
|
// CHECK: return %[[PAD]] : !torch.vtensor<[5,4],f32>
|
||||||
%0 = torch.operator "onnx.Pad"(%arg0, %arg1, %arg2) {torch.onnx.mode = "constant"} : (!torch.vtensor<[3,4],f32>, !torch.vtensor<[4], si64>, !torch.vtensor<[], f32>) -> !torch.vtensor<[5,4],f32>
|
%0 = torch.operator "onnx.Pad"(%arg0, %arg1, %arg2) {torch.onnx.mode = "constant"} : (!torch.vtensor<[3,4],f32>, !torch.vtensor<[4], si64>, !torch.vtensor<[], f32>) -> !torch.vtensor<[5,4],f32>
|
||||||
return %0 : !torch.vtensor<[5,4],f32>
|
return %0 : !torch.vtensor<[5,4],f32>
|
||||||
|
@ -474,13 +474,9 @@ func.func @test_pad(%arg0: !torch.vtensor<[3,4],f32>, %arg1: !torch.vtensor<[4],
|
||||||
// CHECK-LABEL: @test_pad_optional_constant
|
// CHECK-LABEL: @test_pad_optional_constant
|
||||||
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[3,4],f32>
|
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[3,4],f32>
|
||||||
// CHECK-SAME: %[[ARG1:.*]]: !torch.vtensor<[4],si64>
|
// CHECK-SAME: %[[ARG1:.*]]: !torch.vtensor<[4],si64>
|
||||||
|
// CHECK: %[[VAL:.+]] = torch.constant.float 0
|
||||||
// CHECK: %[[CONST_STR:.*]] = torch.constant.str "constant"
|
// CHECK: %[[CONST_STR:.*]] = torch.constant.str "constant"
|
||||||
// CHECK: %[[NONE:.*]] = torch.constant.none
|
// CHECK: torch.aten.pad %[[ARG0]], %{{.*}}, %[[CONST_STR]], %[[VAL]] : !torch.vtensor<[3,4],f32>, !torch.list<int>, !torch.str, !torch.float -> !torch.vtensor<[5,4],f32>
|
||||||
// CHECK: %[[FALSE:.*]] = torch.constant.bool false
|
|
||||||
// CHECK: %[[SEVEN:.*]] = torch.constant.int 7
|
|
||||||
// CHECK: %[[DTYPE:.*]] = torch.aten.to.dtype %0, %[[SEVEN]], %[[FALSE]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[3,4],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[],f64>
|
|
||||||
// CHECK: %[[ITEM:.*]] = torch.aten.item %[[DTYPE]] : !torch.vtensor<[],f64> -> !torch.float
|
|
||||||
// CHECK: torch.aten.pad %[[ARG0]], %{{.*}}, %[[CONST_STR]], %[[ITEM]] : !torch.vtensor<[3,4],f32>, !torch.list<int>, !torch.str, !torch.float -> !torch.vtensor<[5,4],f32>
|
|
||||||
|
|
||||||
func.func @test_pad_optional_constant(%arg0: !torch.vtensor<[3,4],f32>, %arg1: !torch.vtensor<[4], si64>) -> !torch.vtensor<[5,4],f32> attributes {torch.onnx_meta.opset_version = 19 : si64} {
|
func.func @test_pad_optional_constant(%arg0: !torch.vtensor<[3,4],f32>, %arg1: !torch.vtensor<[4], si64>) -> !torch.vtensor<[5,4],f32> attributes {torch.onnx_meta.opset_version = 19 : si64} {
|
||||||
%0 = torch.operator "onnx.Pad"(%arg0, %arg1) {torch.onnx.mode = "constant"} : (!torch.vtensor<[3,4],f32>, !torch.vtensor<[4], si64>) -> !torch.vtensor<[5,4],f32>
|
%0 = torch.operator "onnx.Pad"(%arg0, %arg1) {torch.onnx.mode = "constant"} : (!torch.vtensor<[3,4],f32>, !torch.vtensor<[4], si64>) -> !torch.vtensor<[5,4],f32>
|
||||||
|
|
Loading…
Reference in New Issue