mirror of https://github.com/llvm/torch-mlir
Adds misc fixes for some padding related issues (#3528)
This patch adds a few misc pad op related changes: 1. Addresses issue <https://github.com/llvm/torch-mlir/issues/3457> 2. Addresses issue <https://github.com/llvm/torch-mlir/issues/3442> 3. Fixes the padding order for asymmetrically padded onnx.Conv ops 4. Enables passing quantization through those onnx.Conv op pre-paddings 5. Modifies the torch-to-linalg lowering of AtenReplicationPad2d op to enable support for input rank != 4 Unfortunately, even with all of these changes, the e2e tests for the ReplicationPad2d still fail the onnx config, since the torch export procedure for rearranging the pad order is complicated enough that the padding ints end up not being able to fold back to constants.pull/3541/head
parent
b38585e077
commit
0fb8b017d8
|
@ -1343,11 +1343,12 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
|
|||
SmallVector<Value> padsRearrange;
|
||||
SmallVector<Value> inputPaddingList;
|
||||
for (uint32_t i = 0; i < padding.size() / 2; i++) {
|
||||
padsRearrange.emplace_back(rewriter.create<Torch::ConstantIntOp>(
|
||||
binder.getLoc(), rewriter.getI64IntegerAttr(padding[i])));
|
||||
padsRearrange.emplace_back(rewriter.create<Torch::ConstantIntOp>(
|
||||
binder.getLoc(), rewriter.getI64IntegerAttr(
|
||||
padding[(padding.size() / 2) + i])));
|
||||
padding[padding.size() / 2 - i - 1])));
|
||||
padsRearrange.emplace_back(rewriter.create<Torch::ConstantIntOp>(
|
||||
binder.getLoc(),
|
||||
rewriter.getI64IntegerAttr(padding[padding.size() - i - 1])));
|
||||
inputPaddingList.emplace_back(
|
||||
rewriter.create<Torch::ConstantIntOp>(
|
||||
binder.getLoc(), rewriter.getI64IntegerAttr(0)));
|
||||
|
|
|
@ -2233,41 +2233,84 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
|
|||
binder.op, "The axes parameter is not supported yet");
|
||||
}
|
||||
if (binder.tensorOperandAtIndex(data, 0) ||
|
||||
binder.tensorOperandAtIndex(pads, 1) ||
|
||||
binder.tensorResultType(resultType) ||
|
||||
binder.customOpNameStringAttr(mode, "mode", "constant"))
|
||||
return failure();
|
||||
bool cstMode = (mode == "constant");
|
||||
|
||||
// get input rank
|
||||
auto dataOpTy = cast<Torch::ValueTensorType>(data.getType());
|
||||
TensorType dataTensor = dataOpTy.toBuiltinTensor();
|
||||
if (!dataTensor || !dataTensor.hasRank())
|
||||
return rewriter.notifyMatchFailure(
|
||||
binder.op, "pad length unknown and data operand unranked");
|
||||
int64_t dataRank = dataTensor.getRank();
|
||||
int64_t padsSize = 2 * dataRank;
|
||||
|
||||
Location loc = binder.getLoc();
|
||||
|
||||
// Get pads shape and rank. The pads tensor is expected to be 1-D
|
||||
// tensor.
|
||||
auto padsTensorType = cast<Torch::ValueTensorType>(pads.getType());
|
||||
if (!padsTensorType || !padsTensorType.hasSizes()) {
|
||||
return rewriter.notifyMatchFailure(binder.op,
|
||||
"Expect non empty pad tensor");
|
||||
}
|
||||
ArrayRef<int64_t> padsShape = padsTensorType.getSizes();
|
||||
int64_t padsRank = padsShape.size();
|
||||
if (padsRank != 1)
|
||||
return rewriter.notifyMatchFailure(binder.op,
|
||||
"expect 1-d pad tensor");
|
||||
// get pads (earlier versions use an attribute, newer versions use a
|
||||
// tensor input)
|
||||
SmallVector<Value> padsTensorValue;
|
||||
if (binder.tensorOperandAtIndex(pads, 1)) {
|
||||
SmallVector<int64_t> defaultPads(2 * dataRank, 0);
|
||||
SmallVector<int64_t> padInts;
|
||||
if (binder.s64IntegerArrayAttr(padInts, "pads", defaultPads))
|
||||
return rewriter.notifyMatchFailure(binder.op,
|
||||
"pads binder failure");
|
||||
// opset_version 1 uses the attribute name "paddings"
|
||||
if (padInts == defaultPads) {
|
||||
SmallVector<int64_t> paddingsInts;
|
||||
if (binder.s64IntegerArrayAttr(paddingsInts, "paddings",
|
||||
defaultPads))
|
||||
return rewriter.notifyMatchFailure(binder.op,
|
||||
"paddings binder failure");
|
||||
padInts = paddingsInts;
|
||||
}
|
||||
for (auto p : padInts)
|
||||
padsTensorValue.push_back(rewriter.create<Torch::ConstantIntOp>(
|
||||
loc, rewriter.getI64IntegerAttr(p)));
|
||||
} else {
|
||||
// Get pads shape and rank. The pads tensor is expected to be 1-D
|
||||
// tensor.
|
||||
auto padsTensorType = cast<Torch::ValueTensorType>(pads.getType());
|
||||
if (!padsTensorType || !padsTensorType.hasSizes()) {
|
||||
return rewriter.notifyMatchFailure(binder.op,
|
||||
"Expect non empty pad tensor");
|
||||
}
|
||||
ArrayRef<int64_t> padsShape = padsTensorType.getSizes();
|
||||
int64_t padsRank = padsShape.size();
|
||||
if (padsRank != 1)
|
||||
return rewriter.notifyMatchFailure(binder.op,
|
||||
"expect 1-d pad tensor");
|
||||
if (padsShape[0] != Torch::kUnknownSize) {
|
||||
// As per onnx.Pad documentation, padSize = 2*num_data_axes
|
||||
// (if axes param not passed). Need to be updated when adding
|
||||
// support for `axes` param.
|
||||
padsSize = padsShape[0];
|
||||
}
|
||||
|
||||
int64_t padsSize = padsShape[0];
|
||||
if (padsSize == Torch::kUnknownSize) {
|
||||
// As per onnx.Pad documentation, padSize = 2*num_data_axes
|
||||
// (if axes param not passed). Need to be updated when adding
|
||||
// support for `axes` param.
|
||||
auto dataOpTy = cast<Torch::ValueTensorType>(data.getType());
|
||||
TensorType dataTensor = dataOpTy.toBuiltinTensor();
|
||||
if (!dataTensor || !dataTensor.hasRank())
|
||||
return rewriter.notifyMatchFailure(
|
||||
binder.op, "pad length unknown and data operand unranked");
|
||||
int64_t dataRank = dataTensor.getRank();
|
||||
padsSize = 2 * dataRank;
|
||||
// Extract all the values of 1-D pad tensor and create a list of all
|
||||
// these values as torch.pad op expects pad list.
|
||||
Value constZero = rewriter.create<Torch::ConstantIntOp>(
|
||||
loc, rewriter.getI64IntegerAttr(0));
|
||||
SmallVector<int64_t> emptyShape;
|
||||
Type padsElemType = Torch::ValueTensorType::get(
|
||||
padsTensorType.getContext(), emptyShape,
|
||||
padsTensorType.getOptionalDtype());
|
||||
for (uint32_t i = 0; i < padsSize; ++i) {
|
||||
Value index = rewriter.create<Torch::ConstantIntOp>(
|
||||
loc, rewriter.getI64IntegerAttr(i));
|
||||
auto select = rewriter.create<Torch::AtenSelectIntOp>(
|
||||
loc, padsElemType, pads, constZero, index);
|
||||
Value selectInt = rewriter.create<Torch::AtenItemOp>(
|
||||
loc, rewriter.getType<Torch::IntType>(), select);
|
||||
padsTensorValue.push_back(selectInt);
|
||||
}
|
||||
}
|
||||
|
||||
Value constantValue;
|
||||
if (binder.getNumOperands() >= 3) {
|
||||
if (binder.getNumOperands() >= 3 && cstMode) {
|
||||
if (!binder.tensorOperandAtIndex(constantValue, 2)) {
|
||||
auto constTy =
|
||||
dyn_cast<Torch::BaseTensorType>(constantValue.getType());
|
||||
|
@ -2283,38 +2326,28 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
|
|||
}
|
||||
}
|
||||
|
||||
if (!constantValue) {
|
||||
if (!constantValue && cstMode) {
|
||||
auto dataTensorType = cast<Torch::ValueTensorType>(data.getType());
|
||||
if (isa<IntegerType>(dataTensorType.getDtype()))
|
||||
constantValue = rewriter.create<Torch::ConstantIntOp>(
|
||||
loc, rewriter.getI64IntegerAttr(0));
|
||||
if (isa<FloatType>(dataTensorType.getDtype()))
|
||||
// Earlier versions used a FLOAT attribute to store the constant
|
||||
// value. The following will pick up on any non-default value attr if
|
||||
// provided.
|
||||
float constantFloat;
|
||||
if (isa<FloatType>(dataTensorType.getDtype()) &&
|
||||
!binder.f32FloatAttr(constantFloat, "value", 0.0f))
|
||||
constantValue = rewriter.create<Torch::ConstantFloatOp>(
|
||||
loc, rewriter.getF64FloatAttr(0.0f));
|
||||
loc, rewriter.getF64FloatAttr(constantFloat));
|
||||
|
||||
if (!constantValue)
|
||||
return rewriter.notifyMatchFailure(
|
||||
binder.op, "expected integer or float data tensor");
|
||||
}
|
||||
|
||||
// Extract all the values of 1-D pad tensor and create a list of all
|
||||
// these values as torch.pad op expects pad list.
|
||||
Value constZero = rewriter.create<Torch::ConstantIntOp>(
|
||||
loc, rewriter.getI64IntegerAttr(0));
|
||||
SmallVector<Value> padsTensorValue;
|
||||
SmallVector<int64_t> emptyShape;
|
||||
Type padsElemType =
|
||||
Torch::ValueTensorType::get(padsTensorType.getContext(), emptyShape,
|
||||
padsTensorType.getOptionalDtype());
|
||||
for (uint32_t i = 0; i < padsSize; ++i) {
|
||||
Value index = rewriter.create<Torch::ConstantIntOp>(
|
||||
loc, rewriter.getI64IntegerAttr(i));
|
||||
auto select = rewriter.create<Torch::AtenSelectIntOp>(
|
||||
loc, padsElemType, pads, constZero, index);
|
||||
Value selectInt = rewriter.create<Torch::AtenItemOp>(
|
||||
loc, rewriter.getType<Torch::IntType>(), select);
|
||||
padsTensorValue.push_back(selectInt);
|
||||
}
|
||||
// for modes other than "constant" a value is not required
|
||||
if (!cstMode)
|
||||
constantValue = rewriter.create<Torch::ConstantNoneOp>(loc);
|
||||
|
||||
// The torch.pad op expects a different arrangement of padding pairs for
|
||||
// each dimension as compared to the onnx.pad op. Rearrange the pad
|
||||
|
@ -2335,6 +2368,20 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
|
|||
Torch::ListType::get(rewriter.getType<Torch::IntType>()),
|
||||
padsRearrange)
|
||||
.getResult();
|
||||
|
||||
// lowering to AtenConstantPadNdOp directly allows passing any torch
|
||||
// scalar type for the value, whereas AtenPadOp takes an optional float
|
||||
// type.
|
||||
if (cstMode && !isa<Torch::NoneType>(constantValue.getType())) {
|
||||
rewriter.replaceOpWithNewOp<Torch::AtenConstantPadNdOp>(
|
||||
binder.op, resultType, data, padsSizeList, constantValue);
|
||||
return success();
|
||||
}
|
||||
|
||||
// translate a few mismatching mode names ONNX -> Torch
|
||||
mode = (mode == "edge") ? "replicate" : mode;
|
||||
mode = (mode == "wrap") ? "circular" : mode;
|
||||
|
||||
Value modeVal = rewriter.create<Torch::ConstantStrOp>(
|
||||
loc, rewriter.getStringAttr(mode));
|
||||
|
||||
|
|
|
@ -97,8 +97,12 @@ public:
|
|||
|
||||
Type newResultType = getTypeConverter()->convertType(op.getType());
|
||||
Type elementType = cast<RankedTensorType>(newResultType).getElementType();
|
||||
|
||||
auto dstOriginalDtype =
|
||||
cast<Torch::ValueTensorType>(op.getType()).getDtype();
|
||||
Value castedValue =
|
||||
convertScalarToDtype(rewriter, loc, adaptor.getValue(), elementType);
|
||||
convertScalarToDtype(rewriter, loc, adaptor.getValue(), elementType,
|
||||
std::nullopt, dstOriginalDtype);
|
||||
|
||||
Type padType = tensor::PadOp::inferResultType(
|
||||
cast<RankedTensorType>(self.getType()), staticLow, staticHigh);
|
||||
|
@ -209,26 +213,38 @@ public:
|
|||
Value one = getConstant(rewriter, loc, 1, indexType);
|
||||
Value hDimSizeMinusOne = createSub(hDimSize, one);
|
||||
Value vDimSizeMinusOne = createSub(vDimSize, one);
|
||||
SmallVector<Value> allOneStrides(numDims, one);
|
||||
SmallVector<Value> allOneStridesVal(numDims, one);
|
||||
SmallVector<OpFoldResult> allOneStrides =
|
||||
getAsOpFoldResult(allOneStridesVal);
|
||||
|
||||
SmallVector<Value> extractOffsetsLT(numDims, zero);
|
||||
extractOffsetsLT[hDim] = zero;
|
||||
extractOffsetsLT[vDim] = zero;
|
||||
SmallVector<Value> extractShapeLR(numDims, one);
|
||||
extractShapeLR[hDim] = one;
|
||||
extractShapeLR[vDim] = vDimSize;
|
||||
SmallVector<Value> extractOffsetsLTVal(numDims, zero);
|
||||
extractOffsetsLTVal[hDim] = zero;
|
||||
extractOffsetsLTVal[vDim] = zero;
|
||||
SmallVector<OpFoldResult> extractOffsetsLT =
|
||||
getAsOpFoldResult(extractOffsetsLTVal);
|
||||
SmallVector<Value> extractShapeLRVal(numDims, one);
|
||||
extractShapeLRVal[hDim] = one;
|
||||
extractShapeLRVal[vDim] = vDimSize;
|
||||
SmallVector<OpFoldResult> extractShapeLR =
|
||||
getAsOpFoldResult(extractShapeLRVal);
|
||||
|
||||
SmallVector<Value> extractOffsetsRight(numDims, zero);
|
||||
extractOffsetsRight[hDim] = hDimSizeMinusOne;
|
||||
extractOffsetsRight[vDim] = zero;
|
||||
SmallVector<Value> extractOffsetsRightVal(numDims, zero);
|
||||
extractOffsetsRightVal[hDim] = hDimSizeMinusOne;
|
||||
extractOffsetsRightVal[vDim] = zero;
|
||||
SmallVector<OpFoldResult> extractOffsetsRight =
|
||||
getAsOpFoldResult(extractOffsetsRightVal);
|
||||
|
||||
SmallVector<Value> extractOffsetsBottom(numDims, zero);
|
||||
extractOffsetsBottom[hDim] = zero;
|
||||
extractOffsetsBottom[vDim] = vDimSizeMinusOne;
|
||||
SmallVector<Value> extractOffsetsBottomVal(numDims, zero);
|
||||
extractOffsetsBottomVal[hDim] = zero;
|
||||
extractOffsetsBottomVal[vDim] = vDimSizeMinusOne;
|
||||
SmallVector<OpFoldResult> extractOffsetsBottom =
|
||||
getAsOpFoldResult(extractOffsetsBottomVal);
|
||||
|
||||
SmallVector<Value> extractShapeTB(numDims, one);
|
||||
extractShapeTB[hDim] = hDimSize;
|
||||
extractShapeTB[vDim] = one;
|
||||
SmallVector<Value> extractShapeTBVal(numDims, one);
|
||||
extractShapeTBVal[hDim] = hDimSize;
|
||||
extractShapeTBVal[vDim] = one;
|
||||
SmallVector<OpFoldResult> extractShapeTB =
|
||||
getAsOpFoldResult(extractShapeTBVal);
|
||||
|
||||
SmallVector<Value> tensorsLeft;
|
||||
SmallVector<Value> tensorsRight;
|
||||
|
@ -240,24 +256,26 @@ public:
|
|||
Value vCenterLeftSlice = rewriter.create<tensor::ExtractSliceOp>(
|
||||
loc, input, extractOffsetsLT, extractShapeLR, allOneStrides);
|
||||
Value vLeftSlice = vCenterLeftSlice;
|
||||
SmallVector<Value> extractIndices(numDims, zero);
|
||||
if (hasTopPadding) {
|
||||
Value topLeftValue = rewriter.create<tensor::ExtractOp>(
|
||||
loc, input, ValueRange{zero, zero, zero, zero});
|
||||
Value topLeftValue =
|
||||
rewriter.create<tensor::ExtractOp>(loc, input, extractIndices);
|
||||
// pad vCenterLeftSlice on the top
|
||||
SmallVector<int64_t> lowPadding(4, 0);
|
||||
SmallVector<int64_t> highPadding(4, 0);
|
||||
lowPadding[2] = padInts[2];
|
||||
SmallVector<int64_t> lowPadding(numDims, 0);
|
||||
SmallVector<int64_t> highPadding(numDims, 0);
|
||||
lowPadding[vDim] = padInts[2];
|
||||
vLeftSlice = torch_to_linalg::getPaddedTensor(
|
||||
op, rewriter, vLeftSlice, lowPadding, highPadding, topLeftValue);
|
||||
}
|
||||
if (hasBottomPadding) {
|
||||
Value bottomLeftValue = rewriter.create<tensor::ExtractOp>(
|
||||
loc, input, ValueRange{zero, zero, vDimSizeMinusOne, zero});
|
||||
extractIndices[vDim] = vDimSizeMinusOne;
|
||||
Value bottomLeftValue =
|
||||
rewriter.create<tensor::ExtractOp>(loc, input, extractIndices);
|
||||
|
||||
// pad vLeftSlice at the bottom
|
||||
SmallVector<int64_t> lowPadding(4, 0);
|
||||
SmallVector<int64_t> highPadding(4, 0);
|
||||
highPadding[2] = padInts[3];
|
||||
SmallVector<int64_t> lowPadding(numDims, 0);
|
||||
SmallVector<int64_t> highPadding(numDims, 0);
|
||||
highPadding[vDim] = padInts[3];
|
||||
vLeftSlice = torch_to_linalg::getPaddedTensor(
|
||||
op, rewriter, vLeftSlice, lowPadding, highPadding, bottomLeftValue);
|
||||
}
|
||||
|
@ -265,7 +283,7 @@ public:
|
|||
tensorsLeft.push_back(vLeftSlice);
|
||||
}
|
||||
Value leftPadTile =
|
||||
rewriter.create<tensor::ConcatOp>(loc, 3, tensorsLeft);
|
||||
rewriter.create<tensor::ConcatOp>(loc, hDim, tensorsLeft);
|
||||
tensorsRes.push_back(leftPadTile);
|
||||
}
|
||||
if (hasTopPadding) {
|
||||
|
@ -283,33 +301,35 @@ public:
|
|||
tensorsCenter.push_back(bottomHcenterSlice);
|
||||
}
|
||||
}
|
||||
centerTile = rewriter.create<tensor::ConcatOp>(loc, 2, tensorsCenter);
|
||||
centerTile = rewriter.create<tensor::ConcatOp>(loc, vDim, tensorsCenter);
|
||||
tensorsRes.push_back(centerTile);
|
||||
|
||||
if (hasRightPadding) {
|
||||
Value vCenterRightSlice = rewriter.create<tensor::ExtractSliceOp>(
|
||||
loc, input, extractOffsetsRight, extractShapeLR, allOneStrides);
|
||||
Value vRightSlice = vCenterRightSlice;
|
||||
SmallVector<Value> extractIndices(numDims, zero);
|
||||
extractIndices[hDim] = hDimSizeMinusOne;
|
||||
if (hasTopPadding) {
|
||||
Value topRightValue = rewriter.create<tensor::ExtractOp>(
|
||||
loc, input, ValueRange{zero, zero, zero, hDimSizeMinusOne});
|
||||
|
||||
// pad vCenterRightSlice on the top
|
||||
SmallVector<int64_t> lowPadding(4, 0);
|
||||
SmallVector<int64_t> highPadding(4, 0);
|
||||
lowPadding[2] = padInts[2];
|
||||
SmallVector<int64_t> lowPadding(numDims, 0);
|
||||
SmallVector<int64_t> highPadding(numDims, 0);
|
||||
lowPadding[vDim] = padInts[2];
|
||||
vRightSlice = torch_to_linalg::getPaddedTensor(
|
||||
op, rewriter, vRightSlice, lowPadding, highPadding, topRightValue);
|
||||
}
|
||||
if (hasBottomPadding) {
|
||||
Value bottomRightValue = rewriter.create<tensor::ExtractOp>(
|
||||
loc, input,
|
||||
ValueRange{zero, zero, vDimSizeMinusOne, hDimSizeMinusOne});
|
||||
extractIndices[vDim] = vDimSizeMinusOne;
|
||||
Value bottomRightValue =
|
||||
rewriter.create<tensor::ExtractOp>(loc, input, extractIndices);
|
||||
|
||||
// Pad vCenterRightSlice or vRightTopPaddedSlice at the bottom.
|
||||
SmallVector<int64_t> lowPadding(4, 0);
|
||||
SmallVector<int64_t> highPadding(4, 0);
|
||||
highPadding[2] = padInts[3];
|
||||
SmallVector<int64_t> lowPadding(numDims, 0);
|
||||
SmallVector<int64_t> highPadding(numDims, 0);
|
||||
highPadding[vDim] = padInts[3];
|
||||
vRightSlice = torch_to_linalg::getPaddedTensor(
|
||||
op, rewriter, vRightSlice, lowPadding, highPadding,
|
||||
bottomRightValue);
|
||||
|
@ -318,10 +338,10 @@ public:
|
|||
tensorsRight.push_back(vRightSlice);
|
||||
}
|
||||
Value rightPadTile =
|
||||
rewriter.create<tensor::ConcatOp>(loc, 3, tensorsRight);
|
||||
rewriter.create<tensor::ConcatOp>(loc, hDim, tensorsRight);
|
||||
tensorsRes.push_back(rightPadTile);
|
||||
}
|
||||
Value resTensor = rewriter.create<tensor::ConcatOp>(loc, 3, tensorsRes);
|
||||
Value resTensor = rewriter.create<tensor::ConcatOp>(loc, hDim, tensorsRes);
|
||||
Type newResultType = getTypeConverter()->convertType(op.getType());
|
||||
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, resTensor);
|
||||
return success();
|
||||
|
|
|
@ -6379,17 +6379,91 @@ class DecomposeAtenPadOp : public OpRewritePattern<AtenPadOp> {
|
|||
using OpRewritePattern::OpRewritePattern;
|
||||
LogicalResult matchAndRewrite(AtenPadOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
std::string mode;
|
||||
if (!matchPattern(op.getMode(), m_TorchConstantStr(mode)))
|
||||
return rewriter.notifyMatchFailure(op, "mode must be a constant string");
|
||||
|
||||
Value value = op.getValue();
|
||||
if (isa<Torch::OptionalType>(value.getType()))
|
||||
return rewriter.notifyMatchFailure(op, "optional type not supported");
|
||||
if (isa<Torch::NoneType>(value.getType()))
|
||||
value = rewriter.create<Torch::ConstantFloatOp>(
|
||||
op.getLoc(), rewriter.getF64FloatAttr(0));
|
||||
if (mode == "constant") {
|
||||
Value value = op.getValue();
|
||||
if (isa<Torch::OptionalType>(value.getType()))
|
||||
return rewriter.notifyMatchFailure(op, "optional type not supported");
|
||||
if (isa<Torch::NoneType>(value.getType()))
|
||||
value = rewriter.create<Torch::ConstantFloatOp>(
|
||||
op.getLoc(), rewriter.getF64FloatAttr(0));
|
||||
|
||||
rewriter.replaceOpWithNewOp<AtenConstantPadNdOp>(
|
||||
op, op.getType(), op.getSelf(), op.getPad(), value);
|
||||
return success();
|
||||
rewriter.replaceOpWithNewOp<AtenConstantPadNdOp>(
|
||||
op, op.getType(), op.getSelf(), op.getPad(), value);
|
||||
return success();
|
||||
}
|
||||
|
||||
SmallVector<Value> padValues;
|
||||
if (!getListConstructElements(op.getPad(), padValues))
|
||||
return failure();
|
||||
SmallVector<int64_t> padInts;
|
||||
Value usefulPads = op.getPad();
|
||||
uint64_t usefulPadIndexEnd = padValues.size();
|
||||
|
||||
// try to reduce the number of padding dims if possible
|
||||
if (matchPattern(op.getPad(), m_TorchListOfConstantInts(padInts))) {
|
||||
if ((padInts.size() % 2) == 1)
|
||||
return rewriter.notifyMatchFailure(op,
|
||||
"expected an even number of pads");
|
||||
|
||||
for (uint64_t i = padInts.size() - 1; i > 0; i -= 2) {
|
||||
if (padInts[i] != 0 || padInts[i - 1] != 0)
|
||||
break;
|
||||
usefulPadIndexEnd = i - 1;
|
||||
}
|
||||
if (usefulPadIndexEnd == 0) {
|
||||
rewriter.replaceOp(op, op.getSelf());
|
||||
return success();
|
||||
}
|
||||
}
|
||||
|
||||
// we don't have support for 1-D replicate pad, so pass it as 2d if
|
||||
// possible.
|
||||
// TODO: add support for AtenReplicatePad1dOp and remove this.
|
||||
if (mode == "replicate" && usefulPadIndexEnd == 2 && padValues.size() >= 4)
|
||||
usefulPadIndexEnd = 4;
|
||||
|
||||
// make a new list of padding ints if dimensionality reduction can be
|
||||
// performed
|
||||
if (usefulPadIndexEnd < padValues.size()) {
|
||||
ArrayRef<Value> usefulPadValues(padValues.begin(),
|
||||
padValues.begin() + usefulPadIndexEnd);
|
||||
usefulPads = rewriter.create<PrimListConstructOp>(
|
||||
op.getLoc(),
|
||||
rewriter.getType<Torch::ListType>(rewriter.getType<Torch::IntType>()),
|
||||
usefulPadValues);
|
||||
}
|
||||
|
||||
uint64_t numPadDims = usefulPadIndexEnd / 2;
|
||||
|
||||
if (mode == "reflect") {
|
||||
// only support for relectionpad 1d and 2d
|
||||
if (numPadDims == 2) {
|
||||
rewriter.replaceOpWithNewOp<AtenReflectionPad2dOp>(
|
||||
op, op.getType(), op.getSelf(), usefulPads);
|
||||
return success();
|
||||
}
|
||||
if (numPadDims == 1) {
|
||||
rewriter.replaceOpWithNewOp<AtenReflectionPad1dOp>(
|
||||
op, op.getType(), op.getSelf(), usefulPads);
|
||||
return success();
|
||||
}
|
||||
return failure();
|
||||
}
|
||||
|
||||
if (mode == "replicate") {
|
||||
// only support for replication pad 2d
|
||||
if (numPadDims != 2)
|
||||
return failure();
|
||||
rewriter.replaceOpWithNewOp<AtenReplicationPad2dOp>(
|
||||
op, op.getType(), op.getSelf(), usefulPads);
|
||||
return success();
|
||||
}
|
||||
|
||||
return rewriter.notifyMatchFailure(op, "unsupported mode: " + mode);
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
|
|
@ -40,7 +40,8 @@ bool isQCommutingOp(mlir::Operation *op) {
|
|||
// if adding a new commuting op here, be sure to add a
|
||||
// RemoveUnused pattern for that op to clean up afterwards
|
||||
return llvm::isa<AtenTransposeIntOp, AtenReshapeOp, AtenSliceTensorOp,
|
||||
PrimsCollapseOp, AtenViewOp>(op);
|
||||
PrimsCollapseOp, AtenViewOp, AtenPadOp, AtenConstantPadNdOp>(
|
||||
op);
|
||||
}
|
||||
|
||||
// The following conversion takes patterns of the form [op0 -> MPTQT -> dequant
|
||||
|
@ -65,7 +66,7 @@ public:
|
|||
for (unsigned i : QuantInfo<SrcOp>::operandsToQuantize) {
|
||||
Value operand = operands[i];
|
||||
std::stack<mlir::Operation *> commutingOpStack;
|
||||
Value dequantOpd, MPTQTOpd;
|
||||
Value dequantOpd, MPTQTOpd, scale, zeroPoint;
|
||||
for (unsigned k = 0; k < depth + 1; k++) {
|
||||
auto currOp = operand.getDefiningOp();
|
||||
// Case 0 : currOp is a nullptr (e.g., operand is a block argument)
|
||||
|
@ -84,6 +85,8 @@ public:
|
|||
auto MPTQTOp =
|
||||
dequantOpd.getDefiningOp<Aten_MakePerTensorQuantizedTensorOp>();
|
||||
MPTQTOpd = MPTQTOp.getOperand(0);
|
||||
scale = MPTQTOp.getOperand(1);
|
||||
zeroPoint = MPTQTOp.getOperand(2);
|
||||
}
|
||||
// either a dequant was found or chain broken, so break loop
|
||||
break;
|
||||
|
@ -107,6 +110,47 @@ public:
|
|||
commutingOpStack.pop();
|
||||
llvm::SmallVector<Value> currOperands(currOp->getOperands());
|
||||
currOperands[0] = oldOpd;
|
||||
// pad ops aren't quite commuting, so we include some extra logic to
|
||||
// quantize the padding value
|
||||
if (isa<Torch::AtenPadOp, Torch::AtenConstantPadNdOp>(currOp)) {
|
||||
Value floatPadValue = currOperands.back();
|
||||
Value quantPadValue;
|
||||
if (isa<Torch::NoneType>(floatPadValue.getType()))
|
||||
quantPadValue = rewriter.create<AtenFloatScalarOp>(loc, zeroPoint);
|
||||
else {
|
||||
floatPadValue =
|
||||
rewriter.create<AtenFloatScalarOp>(loc, floatPadValue);
|
||||
quantPadValue = rewriter.create<Torch::AtenDivFloatOp>(
|
||||
loc, floatPadValue, scale);
|
||||
quantPadValue = rewriter.create<Torch::AtenAddFloatIntOp>(
|
||||
loc, quantPadValue, zeroPoint);
|
||||
}
|
||||
// clamp pad value to qint range
|
||||
if (auto intType = dyn_cast<mlir::IntegerType>(intDType)) {
|
||||
bool isSigned = intType.isSignedInteger();
|
||||
int64_t width = intType.getWidth();
|
||||
assert(width < 64 &&
|
||||
"quantized int bitwidth should be less than 64");
|
||||
int64_t minInt = isSigned ? -(1 << (width - 1)) : 0;
|
||||
int64_t maxInt = isSigned ? -minInt - 1 : ((1 << width) - 1);
|
||||
Value minQValueFloat = rewriter.create<ConstantFloatOp>(
|
||||
loc, rewriter.getF64FloatAttr(minInt));
|
||||
Value maxQValueFloat = rewriter.create<ConstantFloatOp>(
|
||||
loc, rewriter.getF64FloatAttr(maxInt));
|
||||
SmallVector<int64_t> emptyShape;
|
||||
auto floatTensorType = rewriter.getType<Torch::ValueTensorType>(
|
||||
emptyShape, rewriter.getF64Type());
|
||||
Value quantPadValueTensor = createRank0Tensor(
|
||||
rewriter, loc, floatTensorType, quantPadValue);
|
||||
Value clampedTensor = rewriter.create<Torch::AtenClampOp>(
|
||||
loc, floatTensorType, quantPadValueTensor, minQValueFloat,
|
||||
maxQValueFloat);
|
||||
quantPadValue = rewriter.create<Torch::AtenItemOp>(
|
||||
loc, rewriter.getType<Torch::FloatType>(), clampedTensor);
|
||||
}
|
||||
// quantPadValue is a float, but will get converted/truncated
|
||||
currOperands.back() = quantPadValue;
|
||||
}
|
||||
// get new result type
|
||||
auto oldType = cast<ValueTensorType>(currOp->getResultTypes()[0]);
|
||||
auto intType =
|
||||
|
@ -374,7 +418,8 @@ public:
|
|||
RemoveUnused<Aten_MakePerTensorQuantizedTensorOp>,
|
||||
RemoveUnused<AtenTransposeIntOp>, RemoveUnused<AtenSliceTensorOp>,
|
||||
RemoveUnused<AtenReshapeOp>, RemoveUnused<PrimsCollapseOp>,
|
||||
RemoveUnused<AtenViewOp>,
|
||||
RemoveUnused<AtenViewOp>, RemoveUnused<AtenPadOp>,
|
||||
RemoveUnused<AtenConstantPadNdOp>,
|
||||
QuantizeOperandsPastCommutingOps<AtenConvolutionOp, 5>,
|
||||
QuantizeOperandsPastCommutingOps<AtenReluOp, 0>,
|
||||
QuantizeOperandsPastCommutingOps<AtenMatmulOp, 2>,
|
||||
|
|
|
@ -1014,6 +1014,38 @@ func.func @test_conv_with_strides_padding(%arg0: !torch.vtensor<[1,1,7,5],f32>,
|
|||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: @test_conv_with_asymmetric_padding
|
||||
func.func @test_conv_with_asymmetric_padding(%arg0: !torch.vtensor<[1,1,7,5],f32>, %arg1: !torch.vtensor<[1,1,3,3],f32>) -> !torch.vtensor<[1,1,4,3],f32> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||
// CHECK: %[[int0:.*]] = torch.constant.int 0
|
||||
// CHECK: %[[int2:.*]] = torch.constant.int 2
|
||||
// CHECK: %[[int0_0:.*]] = torch.constant.int 0
|
||||
// CHECK: %[[int2_1:.*]] = torch.constant.int 2
|
||||
// CHECK: %[[int0_2:.*]] = torch.constant.int 0
|
||||
// CHECK: %[[int0_3:.*]] = torch.constant.int 0
|
||||
// CHECK: %[[FakePADS:.*]] = torch.prim.ListConstruct %[[int0_0]], %[[int0_3]] : (!torch.int, !torch.int) -> !torch.list<int>
|
||||
// CHECK: %[[OGPADS:.*]] = torch.prim.ListConstruct %[[int0]], %[[int2]], %[[int2_1]], %[[int0_2]] : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
|
||||
// CHECK: %[[str:.*]] = torch.constant.str "constant"
|
||||
// CHECK: %[[float0:.*]] = torch.constant.float 0.000
|
||||
// CHECK: %[[PrePad:.*]] = torch.aten.pad %arg0, %[[OGPADS]], %[[str]], %[[float0]] : !torch.vtensor<[1,1,7,5],f32>, !torch.list<int>, !torch.str, !torch.float -> !torch.vtensor<[1,1,9,7],f32>
|
||||
// CHECK: %[[C1_1:.*]] = torch.constant.int 1
|
||||
// CHECK: %[[C1_2:.*]] = torch.constant.int 1
|
||||
// CHECK: %[[C2:.*]] = torch.constant.int 2
|
||||
// CHECK: %[[C2_0:.*]] = torch.constant.int 2
|
||||
// CHECK: %[[C0:.*]] = torch.constant.int 0
|
||||
// CHECK: %[[DILATIONS:.*]] = torch.prim.ListConstruct %[[C1_1]], %[[C1_2]] : (!torch.int, !torch.int) -> !torch.list<int>
|
||||
// CHECK: %[[STRIDE:.*]] = torch.prim.ListConstruct %[[C2]], %[[C2_0]] : (!torch.int, !torch.int) -> !torch.list<int>
|
||||
// CHECK: %[[OUTPUT_PADDING:.*]] = torch.prim.ListConstruct %[[C0]], %[[C0]] : (!torch.int, !torch.int) -> !torch.list<int>
|
||||
// CHECK: %[[TRANSPOSED:.*]] = torch.constant.bool false
|
||||
// CHECK: %[[BIAS:.*]] = torch.constant.none
|
||||
// CHECK: %[[GROUPS:.*]] = torch.constant.int 1
|
||||
// CHECK: %[[Conv:.*]] = torch.aten.convolution %[[PrePad]], %arg1, %[[BIAS]], %[[STRIDE]], %[[FakePADS]], %[[DILATIONS]], %[[TRANSPOSED]], %[[OUTPUT_PADDING]], %[[GROUPS]] : !torch.vtensor<[1,1,9,7],f32>, !torch.vtensor<[1,1,3,3],f32>, !torch.none, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int -> !torch.vtensor<[1,1,4,3],f32>
|
||||
// CHECK: return %[[Conv]]
|
||||
%0 = torch.operator "onnx.Conv"(%arg0, %arg1) {torch.onnx.kernel_shape = [3 : si64, 3 : si64], torch.onnx.pads = [2 : si64, 0 : si64, 0 : si64, 2 : si64], torch.onnx.strides = [2 : si64, 2 : si64]} : (!torch.vtensor<[1,1,7,5],f32>, !torch.vtensor<[1,1,3,3],f32>) -> !torch.vtensor<[1,1,4,3],f32>
|
||||
return %0 : !torch.vtensor<[1,1,4,3],f32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: @test_conv_with_bias_strides_padding
|
||||
func.func @test_conv_with_bias_strides_padding(%arg0: !torch.vtensor<[?,?,224,224],f32>, %arg1: !torch.vtensor<[64,3,7,7],f32>, %arg2: !torch.vtensor<[64],f32>) -> !torch.vtensor<[?,64,112,112],f32> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||
// CHECK: %[[C3:.*]] = torch.constant.int 3
|
||||
|
|
|
@ -815,32 +815,40 @@ func.func @test_grid_sampler02(%arg0: !torch.vtensor<[5,10,10,4],f32>, %arg1: !t
|
|||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: @test_grid_sampler03
|
||||
// CHECK: %[[INT1:.*]] = torch.constant.int 1
|
||||
// CHECK: %[[INT0:.*]] = torch.constant.int 0
|
||||
// CHECK: %[[B0:.*]] = torch.constant.bool true
|
||||
// CHECK: %[[A0:.*]] = torch.aten.grid_sampler %arg0, %arg1, %[[INT1]], %[[INT0]], %[[B0]] : !torch.vtensor<[5,10,10,4],f32>
|
||||
func.func @test_grid_sampler03(%arg0: !torch.vtensor<[5,10,10,4],f32>, %arg1: !torch.vtensor<[5,7,8,2],f32>) -> !torch.vtensor<[?,?,?,?],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 17 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||
%0 = torch.operator "onnx.GridSample" (%arg0, %arg1) {torch.onnx.align_corners = 1 : si64, torch.onnx.mode = "nearest", torch.onnx.padding_mode = "zeros"} : (!torch.vtensor<[5,10,10,4],f32>, !torch.vtensor<[5,7,8,2],f32>) -> !torch.vtensor<[?,?,?,?],f32>
|
||||
return %0 : !torch.vtensor<[?,?,?,?],f32>
|
||||
// CHECK-LABEL: func.func @test_oldest_pad
|
||||
func.func @test_oldest_pad(%arg0: !torch.vtensor<[3,4],f32>) -> !torch.vtensor<[5,4],f32> attributes {torch.onnx_meta.opset_version = 1 : si64} {
|
||||
// CHECK: %[[int0:.*]] = torch.constant.int 0
|
||||
// CHECK: %[[int0_0:.*]] = torch.constant.int 0
|
||||
// CHECK: %[[int2:.*]] = torch.constant.int 2
|
||||
// CHECK: %[[int0_1:.*]] = torch.constant.int 0
|
||||
// CHECK: %[[float0:.*]] = torch.constant.float 0.000000e+00
|
||||
// CHECK: %[[list:.*]] = torch.prim.ListConstruct %[[int0_0]], %[[int0_1]], %[[int0]], %[[int2]] : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
|
||||
// CHECK: %[[pad:.*]] = torch.aten.constant_pad_nd %arg0, %[[list]], %[[float0]] : !torch.vtensor<[3,4],f32>, !torch.list<int>, !torch.float -> !torch.vtensor<[5,4],f32>
|
||||
// CHECK: return %[[pad]] : !torch.vtensor<[5,4],f32>
|
||||
%0 = torch.operator "onnx.Pad"(%arg0) {torch.onnx.mode = "constant", torch.onnx.paddings = [0 : si64, 0 : si64, 2 : si64, 0 : si64], torch.onnx.value = 0.000000e+00 : f32} : (!torch.vtensor<[3,4],f32>) -> !torch.vtensor<[5,4],f32>
|
||||
return %0 : !torch.vtensor<[5,4],f32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @test_less_or_equal
|
||||
func.func @test_less_or_equal(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],i1> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 16 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: !torch.vtensor<[3,4,5],f32>
|
||||
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: !torch.vtensor<[3,4,5],f32>
|
||||
// CHECK: torch.aten.le.Tensor %[[ARG0]], %[[ARG1]] : !torch.vtensor<[3,4,5],f32>, !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],i1>
|
||||
%0 = torch.operator "onnx.LessOrEqual"(%arg0, %arg1) : (!torch.vtensor<[3,4,5],f32>, !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],i1>
|
||||
return %0 : !torch.vtensor<[3,4,5],i1>
|
||||
// CHECK-LABEL: func.func @test_old_pad
|
||||
func.func @test_old_pad(%arg0: !torch.vtensor<[3,4],f32>) -> !torch.vtensor<[5,4],f32> attributes {torch.onnx_meta.opset_version = 11 : si64} {
|
||||
// CHECK: %[[int0:.*]] = torch.constant.int 0
|
||||
// CHECK: %[[int0_0:.*]] = torch.constant.int 0
|
||||
// CHECK: %[[int2:.*]] = torch.constant.int 2
|
||||
// CHECK: %[[int0_1:.*]] = torch.constant.int 0
|
||||
// CHECK: %[[float0:.*]] = torch.constant.float 0.000000e+00
|
||||
// CHECK: %[[list:.*]] = torch.prim.ListConstruct %[[int0_0]], %[[int0_1]], %[[int0]], %[[int2]] : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
|
||||
// CHECK: %[[pad:.*]] = torch.aten.constant_pad_nd %arg0, %[[list]], %[[float0]] : !torch.vtensor<[3,4],f32>, !torch.list<int>, !torch.float -> !torch.vtensor<[5,4],f32>
|
||||
// CHECK: return %[[pad]] : !torch.vtensor<[5,4],f32>
|
||||
%0 = torch.operator "onnx.Pad"(%arg0) {torch.onnx.mode = "constant", torch.onnx.pads = [0 : si64, 0 : si64, 2 : si64, 0 : si64], torch.onnx.value = 0.000000e+00 : f32} : (!torch.vtensor<[3,4],f32>) -> !torch.vtensor<[5,4],f32>
|
||||
return %0 : !torch.vtensor<[5,4],f32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @test_pad
|
||||
func.func @test_pad(%arg0: !torch.vtensor<[3,4],f32>, %arg1: !torch.vtensor<[4], si64>, %arg2: !torch.vtensor<[], f32>) -> !torch.vtensor<[5,4],f32> attributes {torch.onnx_meta.opset_version = 19 : si64} {
|
||||
// CHECK: %[[VAL:.+]] = torch.aten.item %arg2 : !torch.vtensor<[],f32> -> !torch.float
|
||||
// CHECK: %[[INT0:.+]] = torch.constant.int 0
|
||||
// CHECK: %[[INT0_0:.+]] = torch.constant.int 0
|
||||
// CHECK: %[[SELECT_0:.+]] = torch.aten.select.int %arg1, %[[INT0]], %[[INT0_0]] : !torch.vtensor<[4],si64>, !torch.int, !torch.int -> !torch.vtensor<[],si64>
|
||||
|
@ -854,9 +862,9 @@ func.func @test_pad(%arg0: !torch.vtensor<[3,4],f32>, %arg1: !torch.vtensor<[4],
|
|||
// CHECK: %[[INT3:.+]] = torch.constant.int 3
|
||||
// CHECK: %[[SELECT_3:.+]] = torch.aten.select.int %arg1, %[[INT0]], %[[INT3]] : !torch.vtensor<[4],si64>, !torch.int, !torch.int -> !torch.vtensor<[],si64>
|
||||
// CHECK: %[[ITEM_3:.+]] = torch.aten.item %[[SELECT_3]] : !torch.vtensor<[],si64> -> !torch.int
|
||||
// CHECK: %[[VAL:.+]] = torch.aten.item %arg2 : !torch.vtensor<[],f32> -> !torch.float
|
||||
// CHECK: %[[LIST:.+]] = torch.prim.ListConstruct %[[ITEM_1]], %[[ITEM_3]], %[[ITEM_0]], %[[ITEM_2]] : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
|
||||
// CHECK: %[[STR:.+]] = torch.constant.str "constant"
|
||||
// CHECK: %[[PAD:.+]] = torch.aten.pad %arg0, %[[LIST]], %[[STR]], %[[VAL]] : !torch.vtensor<[3,4],f32>, !torch.list<int>, !torch.str, !torch.float -> !torch.vtensor<[5,4],f32>
|
||||
// CHECK: %[[PAD:.+]] = torch.aten.constant_pad_nd %arg0, %[[LIST]], %[[VAL]] : !torch.vtensor<[3,4],f32>, !torch.list<int>, !torch.float -> !torch.vtensor<[5,4],f32>
|
||||
// CHECK: return %[[PAD]] : !torch.vtensor<[5,4],f32>
|
||||
%0 = torch.operator "onnx.Pad"(%arg0, %arg1, %arg2) {torch.onnx.mode = "constant"} : (!torch.vtensor<[3,4],f32>, !torch.vtensor<[4], si64>, !torch.vtensor<[], f32>) -> !torch.vtensor<[5,4],f32>
|
||||
return %0 : !torch.vtensor<[5,4],f32>
|
||||
|
@ -864,12 +872,36 @@ func.func @test_pad(%arg0: !torch.vtensor<[3,4],f32>, %arg1: !torch.vtensor<[4],
|
|||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @test_i32pad
|
||||
func.func @test_i32pad(%arg0: !torch.vtensor<[3,4],si32>, %arg1: !torch.vtensor<[4], si64>, %arg2: !torch.vtensor<[], si32>) -> !torch.vtensor<[5,4],si32> attributes {torch.onnx_meta.opset_version = 19 : si64} {
|
||||
// CHECK: %[[INT0:.+]] = torch.constant.int 0
|
||||
// CHECK: %[[INT0_0:.+]] = torch.constant.int 0
|
||||
// CHECK: %[[SELECT_0:.+]] = torch.aten.select.int %arg1, %[[INT0]], %[[INT0_0]] : !torch.vtensor<[4],si64>, !torch.int, !torch.int -> !torch.vtensor<[],si64>
|
||||
// CHECK: %[[ITEM_0:.+]] = torch.aten.item %[[SELECT_0]] : !torch.vtensor<[],si64> -> !torch.int
|
||||
// CHECK: %[[INT1:.+]] = torch.constant.int 1
|
||||
// CHECK: %[[SELECT_1:.+]] = torch.aten.select.int %arg1, %[[INT0]], %[[INT1]] : !torch.vtensor<[4],si64>, !torch.int, !torch.int -> !torch.vtensor<[],si64>
|
||||
// CHECK: %[[ITEM_1:.+]] = torch.aten.item %[[SELECT_1]] : !torch.vtensor<[],si64> -> !torch.int
|
||||
// CHECK: %[[INT2:.+]] = torch.constant.int 2
|
||||
// CHECK: %[[SELECT_2:.+]] = torch.aten.select.int %arg1, %[[INT0]], %[[INT2]] : !torch.vtensor<[4],si64>, !torch.int, !torch.int -> !torch.vtensor<[],si64>
|
||||
// CHECK: %[[ITEM_2:.+]] = torch.aten.item %[[SELECT_2]] : !torch.vtensor<[],si64> -> !torch.int
|
||||
// CHECK: %[[INT3:.+]] = torch.constant.int 3
|
||||
// CHECK: %[[SELECT_3:.+]] = torch.aten.select.int %arg1, %[[INT0]], %[[INT3]] : !torch.vtensor<[4],si64>, !torch.int, !torch.int -> !torch.vtensor<[],si64>
|
||||
// CHECK: %[[ITEM_3:.+]] = torch.aten.item %[[SELECT_3]] : !torch.vtensor<[],si64> -> !torch.int
|
||||
// CHECK: %[[VAL:.+]] = torch.aten.item %arg2 : !torch.vtensor<[],si32> -> !torch.int
|
||||
// CHECK: %[[LIST:.+]] = torch.prim.ListConstruct %[[ITEM_1]], %[[ITEM_3]], %[[ITEM_0]], %[[ITEM_2]] : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
|
||||
// CHECK: %[[PAD:.+]] = torch.aten.constant_pad_nd %arg0, %[[LIST]], %[[VAL]] : !torch.vtensor<[3,4],si32>, !torch.list<int>, !torch.int -> !torch.vtensor<[5,4],si32>
|
||||
// CHECK: return %[[PAD]] : !torch.vtensor<[5,4],si32>
|
||||
%0 = torch.operator "onnx.Pad"(%arg0, %arg1, %arg2) {torch.onnx.mode = "constant"} : (!torch.vtensor<[3,4],si32>, !torch.vtensor<[4], si64>, !torch.vtensor<[], si32>) -> !torch.vtensor<[5,4],si32>
|
||||
return %0 : !torch.vtensor<[5,4],si32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: @test_pad_optional_constant
|
||||
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[3,4],f32>
|
||||
// CHECK-SAME: %[[ARG1:.*]]: !torch.vtensor<[4],si64>
|
||||
// CHECK: %[[VAL:.+]] = torch.constant.float 0
|
||||
// CHECK: %[[CONST_STR:.*]] = torch.constant.str "constant"
|
||||
// CHECK: torch.aten.pad %[[ARG0]], %{{.*}}, %[[CONST_STR]], %[[VAL]] : !torch.vtensor<[3,4],f32>, !torch.list<int>, !torch.str, !torch.float -> !torch.vtensor<[5,4],f32>
|
||||
// CHECK: torch.aten.constant_pad_nd %[[ARG0]], %{{.*}}, %[[VAL]] : !torch.vtensor<[3,4],f32>, !torch.list<int>, !torch.float -> !torch.vtensor<[5,4],f32>
|
||||
|
||||
func.func @test_pad_optional_constant(%arg0: !torch.vtensor<[3,4],f32>, %arg1: !torch.vtensor<[4], si64>) -> !torch.vtensor<[5,4],f32> attributes {torch.onnx_meta.opset_version = 19 : si64} {
|
||||
%0 = torch.operator "onnx.Pad"(%arg0, %arg1) {torch.onnx.mode = "constant"} : (!torch.vtensor<[3,4],f32>, !torch.vtensor<[4], si64>) -> !torch.vtensor<[5,4],f32>
|
||||
|
@ -878,6 +910,34 @@ func.func @test_pad_optional_constant(%arg0: !torch.vtensor<[3,4],f32>, %arg1: !
|
|||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: @test_pad_wrap
|
||||
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[3,4],f32>
|
||||
// CHECK-SAME: %[[ARG1:.*]]: !torch.vtensor<[4],si64>
|
||||
// CHECK: %[[VAL:.+]] = torch.constant.none
|
||||
// CHECK: %[[STR:.+]] = torch.constant.str "circular"
|
||||
// CHECK: torch.aten.pad %[[ARG0]], %{{.*}}, %[[STR]], %[[VAL]] : !torch.vtensor<[3,4],f32>, !torch.list<int>, !torch.str, !torch.none -> !torch.vtensor<[5,4],f32>
|
||||
|
||||
func.func @test_pad_wrap(%arg0: !torch.vtensor<[3,4],f32>, %arg1: !torch.vtensor<[4], si64>) -> !torch.vtensor<[5,4],f32> attributes {torch.onnx_meta.opset_version = 19 : si64} {
|
||||
%0 = torch.operator "onnx.Pad"(%arg0, %arg1) {torch.onnx.mode = "wrap"} : (!torch.vtensor<[3,4],f32>, !torch.vtensor<[4], si64>) -> !torch.vtensor<[5,4],f32>
|
||||
return %0 : !torch.vtensor<[5,4],f32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: @test_pad_edge
|
||||
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[3,4],f32>
|
||||
// CHECK-SAME: %[[ARG1:.*]]: !torch.vtensor<[4],si64>
|
||||
// CHECK: %[[VAL:.+]] = torch.constant.none
|
||||
// CHECK: %[[STR:.+]] = torch.constant.str "replicate"
|
||||
// CHECK: torch.aten.pad %[[ARG0]], %{{.*}}, %[[STR]], %[[VAL]] : !torch.vtensor<[3,4],f32>, !torch.list<int>, !torch.str, !torch.none -> !torch.vtensor<[5,4],f32>
|
||||
|
||||
func.func @test_pad_edge(%arg0: !torch.vtensor<[3,4],f32>, %arg1: !torch.vtensor<[4], si64>) -> !torch.vtensor<[5,4],f32> attributes {torch.onnx_meta.opset_version = 19 : si64} {
|
||||
%0 = torch.operator "onnx.Pad"(%arg0, %arg1) {torch.onnx.mode = "edge"} : (!torch.vtensor<[3,4],f32>, !torch.vtensor<[4], si64>) -> !torch.vtensor<[5,4],f32>
|
||||
return %0 : !torch.vtensor<[5,4],f32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @test_pow
|
||||
func.func @test_pow(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 15 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||
// CHECK: torch.aten.pow.Tensor_Tensor %arg0, %arg1 : !torch.vtensor<[3,4,5],f32>, !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],f32>
|
||||
|
|
|
@ -82,6 +82,48 @@ func.func @matmul_commuting(%arg0: !torch.vtensor<[2,128,32,32],si8>) -> !torch.
|
|||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @mm_pad_commute
|
||||
func.func @mm_pad_commute(%arg0: !torch.vtensor<[8,8],si8>, %arg1: !torch.vtensor<[11,4],si8>) -> !torch.vtensor<[9,4],f32> {
|
||||
// CHECK-DAG: %[[cstQuart:.*]] = torch.constant.float 2.500000e-01
|
||||
// CHECK-DAG: %[[int7:.*]] = torch.constant.int 7
|
||||
// CHECK-DAG: %[[none:.*]] = torch.constant.none
|
||||
// CHECK-DAG: %[[qMax:.*]] = torch.constant.float 1.270000e+02
|
||||
// CHECK-DAG: %[[qMin:.*]] = torch.constant.float -1.280000e+02
|
||||
// CHECK-DAG: %[[padVal:.*]] = torch.constant.float 8.000000e+00
|
||||
// CHECK-DAG: %[[str:.*]] = torch.constant.str "constant"
|
||||
// CHECK-DAG: %[[cstHalf:.*]] = torch.constant.float 5.000000e-01
|
||||
// CHECK-DAG: %[[int0:.*]] = torch.constant.int 0
|
||||
// CHECK-DAG: %[[int1:.*]] = torch.constant.int 1
|
||||
// CHECK-DAG: %[[int2:.*]] = torch.constant.int 2
|
||||
// CHECK: %[[PadList:.*]] = torch.prim.ListConstruct %[[int1]], %[[int2]], %[[int0]], %[[int1]] : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
|
||||
// CHECK: %[[EmptyList:.*]] = torch.prim.ListConstruct : () -> !torch.list<int>
|
||||
// CHECK: %[[Rank0:.*]] = torch.aten.full %[[EmptyList]], %[[padVal]], %[[int7]], %[[none]], %[[none]], %[[none]] : !torch.list<int>, !torch.float, !torch.int, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[],f64>
|
||||
// CHECK: %[[Clamp:.*]] = torch.aten.clamp %[[Rank0]], %[[qMin]], %[[qMax]] : !torch.vtensor<[],f64>, !torch.float, !torch.float -> !torch.vtensor<[],f64>
|
||||
// CHECK: %[[Item:.*]] = torch.aten.item %[[Clamp]] : !torch.vtensor<[],f64> -> !torch.float
|
||||
// CHECK: %[[NewPad:.*]] = torch.aten.pad %arg0, %[[PadList]], %[[str]], %[[Item]] : !torch.vtensor<[8,8],si8>, !torch.list<int>, !torch.str, !torch.float -> !torch.vtensor<[9,11],si8>
|
||||
// CHECK: %[[NewMPTQT:.*]] = torch.aten._make_per_tensor_quantized_tensor %[[NewPad]], %[[cstHalf]], %[[int1]] : !torch.vtensor<[9,11],si8>, !torch.float, !torch.int -> !torch.vtensor<[9,11],!torch.qint8>
|
||||
// CHECK: %[[OtherMPTQT:.*]] = torch.aten._make_per_tensor_quantized_tensor %arg1, %[[cstHalf]], %[[int0]] : !torch.vtensor<[11,4],si8>, !torch.float, !torch.int -> !torch.vtensor<[11,4],!torch.qint8>
|
||||
// CHECK: %[[MM:.*]] = torch.aten.mm %[[NewMPTQT]], %[[OtherMPTQT]] : !torch.vtensor<[9,11],!torch.qint8>, !torch.vtensor<[11,4],!torch.qint8> -> !torch.vtensor<[9,4],!torch.qint32>
|
||||
%scale = torch.constant.float 0.5
|
||||
%false = torch.constant.bool false
|
||||
%zero = torch.constant.int 0
|
||||
%one = torch.constant.int 1
|
||||
%two = torch.constant.int 2
|
||||
%floatpad = torch.constant.float 3.5
|
||||
%zp = torch.constant.int -128
|
||||
%6 = torch.aten._make_per_tensor_quantized_tensor %arg0, %scale, %one : !torch.vtensor<[8,8],si8>, !torch.float, !torch.int -> !torch.vtensor<[8,8],!torch.qint8>
|
||||
%7 = torch.aten.dequantize.tensor %6 : !torch.vtensor<[8,8],!torch.qint8> -> !torch.vtensor<[8,8],f32>
|
||||
%list = torch.prim.ListConstruct %one, %two, %zero, %one : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
|
||||
%str = torch.constant.str "constant"
|
||||
%pad = torch.aten.pad %7, %list, %str, %floatpad : !torch.vtensor<[8,8],f32>, !torch.list<int>, !torch.str, !torch.float -> !torch.vtensor<[9,11],f32>
|
||||
%12 = torch.aten._make_per_tensor_quantized_tensor %arg1, %scale, %zero : !torch.vtensor<[11,4],si8>, !torch.float, !torch.int -> !torch.vtensor<[11,4],!torch.qint8>
|
||||
%13 = torch.aten.dequantize.tensor %12 : !torch.vtensor<[11,4],!torch.qint8> -> !torch.vtensor<[11,4],f32>
|
||||
%16 = torch.aten.mm %pad, %13 : !torch.vtensor<[9,11],f32>, !torch.vtensor<[11,4],f32> -> !torch.vtensor<[9,4],f32>
|
||||
return %16 : !torch.vtensor<[9,4],f32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: @convolution_bias
|
||||
func.func @convolution_bias(%arg0: !torch.vtensor<[1,3,8,8],si8>, %arg1: !torch.vtensor<[3,3,2,2],si8>, %arg2 : !torch.vtensor<[3], f32>) -> !torch.vtensor<[1,3,7,7],f32> {
|
||||
%scale = torch.constant.float 0.5
|
||||
|
|
Loading…
Reference in New Issue