Adds misc fixes for some padding related issues (#3528)

This patch adds a few miscellaneous pad-op-related changes:

1. Addresses issue <https://github.com/llvm/torch-mlir/issues/3457>
2. Addresses issue <https://github.com/llvm/torch-mlir/issues/3442>
3. Fixes the padding order for asymmetrically padded onnx.Conv ops
4. Enables passing quantization through those onnx.Conv op pre-paddings
5. Modifies the torch-to-linalg lowering of AtenReplicationPad2d op to
enable support for input rank != 4

Unfortunately, even with all of these changes, the e2e tests for
ReplicationPad2d still fail under the onnx config, since the torch export
procedure for rearranging the pad order is complicated enough that the
padding ints end up unable to fold back to constants.
pull/3541/head
zjgarvey 2024-07-11 18:01:45 -07:00 committed by GitHub
parent b38585e077
commit 0fb8b017d8
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 443 additions and 122 deletions

View File

@ -1343,11 +1343,12 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
SmallVector<Value> padsRearrange; SmallVector<Value> padsRearrange;
SmallVector<Value> inputPaddingList; SmallVector<Value> inputPaddingList;
for (uint32_t i = 0; i < padding.size() / 2; i++) { for (uint32_t i = 0; i < padding.size() / 2; i++) {
padsRearrange.emplace_back(rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getI64IntegerAttr(padding[i])));
padsRearrange.emplace_back(rewriter.create<Torch::ConstantIntOp>( padsRearrange.emplace_back(rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getI64IntegerAttr( binder.getLoc(), rewriter.getI64IntegerAttr(
padding[(padding.size() / 2) + i]))); padding[padding.size() / 2 - i - 1])));
padsRearrange.emplace_back(rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(),
rewriter.getI64IntegerAttr(padding[padding.size() - i - 1])));
inputPaddingList.emplace_back( inputPaddingList.emplace_back(
rewriter.create<Torch::ConstantIntOp>( rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getI64IntegerAttr(0))); binder.getLoc(), rewriter.getI64IntegerAttr(0)));

View File

@ -2233,12 +2233,44 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
binder.op, "The axes parameter is not supported yet"); binder.op, "The axes parameter is not supported yet");
} }
if (binder.tensorOperandAtIndex(data, 0) || if (binder.tensorOperandAtIndex(data, 0) ||
binder.tensorOperandAtIndex(pads, 1) ||
binder.tensorResultType(resultType) || binder.tensorResultType(resultType) ||
binder.customOpNameStringAttr(mode, "mode", "constant")) binder.customOpNameStringAttr(mode, "mode", "constant"))
return failure(); return failure();
bool cstMode = (mode == "constant");
// get input rank
auto dataOpTy = cast<Torch::ValueTensorType>(data.getType());
TensorType dataTensor = dataOpTy.toBuiltinTensor();
if (!dataTensor || !dataTensor.hasRank())
return rewriter.notifyMatchFailure(
binder.op, "pad length unknown and data operand unranked");
int64_t dataRank = dataTensor.getRank();
int64_t padsSize = 2 * dataRank;
Location loc = binder.getLoc(); Location loc = binder.getLoc();
// get pads (earlier versions use an attribute, newer versions use a
// tensor input)
SmallVector<Value> padsTensorValue;
if (binder.tensorOperandAtIndex(pads, 1)) {
SmallVector<int64_t> defaultPads(2 * dataRank, 0);
SmallVector<int64_t> padInts;
if (binder.s64IntegerArrayAttr(padInts, "pads", defaultPads))
return rewriter.notifyMatchFailure(binder.op,
"pads binder failure");
// opset_version 1 uses the attribute name "paddings"
if (padInts == defaultPads) {
SmallVector<int64_t> paddingsInts;
if (binder.s64IntegerArrayAttr(paddingsInts, "paddings",
defaultPads))
return rewriter.notifyMatchFailure(binder.op,
"paddings binder failure");
padInts = paddingsInts;
}
for (auto p : padInts)
padsTensorValue.push_back(rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(p)));
} else {
// Get pads shape and rank. The pads tensor is expected to be 1-D // Get pads shape and rank. The pads tensor is expected to be 1-D
// tensor. // tensor.
auto padsTensorType = cast<Torch::ValueTensorType>(pads.getType()); auto padsTensorType = cast<Torch::ValueTensorType>(pads.getType());
@ -2251,23 +2283,34 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
if (padsRank != 1) if (padsRank != 1)
return rewriter.notifyMatchFailure(binder.op, return rewriter.notifyMatchFailure(binder.op,
"expect 1-d pad tensor"); "expect 1-d pad tensor");
if (padsShape[0] != Torch::kUnknownSize) {
int64_t padsSize = padsShape[0];
if (padsSize == Torch::kUnknownSize) {
// As per onnx.Pad documentation, padSize = 2*num_data_axes // As per onnx.Pad documentation, padSize = 2*num_data_axes
// (if axes param not passed). Need to be updated when adding // (if axes param not passed). Need to be updated when adding
// support for `axes` param. // support for `axes` param.
auto dataOpTy = cast<Torch::ValueTensorType>(data.getType()); padsSize = padsShape[0];
TensorType dataTensor = dataOpTy.toBuiltinTensor(); }
if (!dataTensor || !dataTensor.hasRank())
return rewriter.notifyMatchFailure( // Extract all the values of 1-D pad tensor and create a list of all
binder.op, "pad length unknown and data operand unranked"); // these values as torch.pad op expects pad list.
int64_t dataRank = dataTensor.getRank(); Value constZero = rewriter.create<Torch::ConstantIntOp>(
padsSize = 2 * dataRank; loc, rewriter.getI64IntegerAttr(0));
SmallVector<int64_t> emptyShape;
Type padsElemType = Torch::ValueTensorType::get(
padsTensorType.getContext(), emptyShape,
padsTensorType.getOptionalDtype());
for (uint32_t i = 0; i < padsSize; ++i) {
Value index = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(i));
auto select = rewriter.create<Torch::AtenSelectIntOp>(
loc, padsElemType, pads, constZero, index);
Value selectInt = rewriter.create<Torch::AtenItemOp>(
loc, rewriter.getType<Torch::IntType>(), select);
padsTensorValue.push_back(selectInt);
}
} }
Value constantValue; Value constantValue;
if (binder.getNumOperands() >= 3) { if (binder.getNumOperands() >= 3 && cstMode) {
if (!binder.tensorOperandAtIndex(constantValue, 2)) { if (!binder.tensorOperandAtIndex(constantValue, 2)) {
auto constTy = auto constTy =
dyn_cast<Torch::BaseTensorType>(constantValue.getType()); dyn_cast<Torch::BaseTensorType>(constantValue.getType());
@ -2283,38 +2326,28 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
} }
} }
if (!constantValue) { if (!constantValue && cstMode) {
auto dataTensorType = cast<Torch::ValueTensorType>(data.getType()); auto dataTensorType = cast<Torch::ValueTensorType>(data.getType());
if (isa<IntegerType>(dataTensorType.getDtype())) if (isa<IntegerType>(dataTensorType.getDtype()))
constantValue = rewriter.create<Torch::ConstantIntOp>( constantValue = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(0)); loc, rewriter.getI64IntegerAttr(0));
if (isa<FloatType>(dataTensorType.getDtype())) // Earlier versions used a FLOAT attribute to store the constant
// value. The following will pick up on any non-default value attr if
// provided.
float constantFloat;
if (isa<FloatType>(dataTensorType.getDtype()) &&
!binder.f32FloatAttr(constantFloat, "value", 0.0f))
constantValue = rewriter.create<Torch::ConstantFloatOp>( constantValue = rewriter.create<Torch::ConstantFloatOp>(
loc, rewriter.getF64FloatAttr(0.0f)); loc, rewriter.getF64FloatAttr(constantFloat));
if (!constantValue) if (!constantValue)
return rewriter.notifyMatchFailure( return rewriter.notifyMatchFailure(
binder.op, "expected integer or float data tensor"); binder.op, "expected integer or float data tensor");
} }
// Extract all the values of 1-D pad tensor and create a list of all // for modes other than "constant" a value is not required
// these values as torch.pad op expects pad list. if (!cstMode)
Value constZero = rewriter.create<Torch::ConstantIntOp>( constantValue = rewriter.create<Torch::ConstantNoneOp>(loc);
loc, rewriter.getI64IntegerAttr(0));
SmallVector<Value> padsTensorValue;
SmallVector<int64_t> emptyShape;
Type padsElemType =
Torch::ValueTensorType::get(padsTensorType.getContext(), emptyShape,
padsTensorType.getOptionalDtype());
for (uint32_t i = 0; i < padsSize; ++i) {
Value index = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(i));
auto select = rewriter.create<Torch::AtenSelectIntOp>(
loc, padsElemType, pads, constZero, index);
Value selectInt = rewriter.create<Torch::AtenItemOp>(
loc, rewriter.getType<Torch::IntType>(), select);
padsTensorValue.push_back(selectInt);
}
// The torch.pad op expects a different arrangement of padding pairs for // The torch.pad op expects a different arrangement of padding pairs for
// each dimension as compared to the onnx.pad op. Rearrange the pad // each dimension as compared to the onnx.pad op. Rearrange the pad
@ -2335,6 +2368,20 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
Torch::ListType::get(rewriter.getType<Torch::IntType>()), Torch::ListType::get(rewriter.getType<Torch::IntType>()),
padsRearrange) padsRearrange)
.getResult(); .getResult();
// lowering to AtenConstantPadNdOp directly allows passing any torch
// scalar type for the value, whereas AtenPadOp takes an optional float
// type.
if (cstMode && !isa<Torch::NoneType>(constantValue.getType())) {
rewriter.replaceOpWithNewOp<Torch::AtenConstantPadNdOp>(
binder.op, resultType, data, padsSizeList, constantValue);
return success();
}
// translate a few mismatching mode names ONNX -> Torch
mode = (mode == "edge") ? "replicate" : mode;
mode = (mode == "wrap") ? "circular" : mode;
Value modeVal = rewriter.create<Torch::ConstantStrOp>( Value modeVal = rewriter.create<Torch::ConstantStrOp>(
loc, rewriter.getStringAttr(mode)); loc, rewriter.getStringAttr(mode));

View File

@ -97,8 +97,12 @@ public:
Type newResultType = getTypeConverter()->convertType(op.getType()); Type newResultType = getTypeConverter()->convertType(op.getType());
Type elementType = cast<RankedTensorType>(newResultType).getElementType(); Type elementType = cast<RankedTensorType>(newResultType).getElementType();
auto dstOriginalDtype =
cast<Torch::ValueTensorType>(op.getType()).getDtype();
Value castedValue = Value castedValue =
convertScalarToDtype(rewriter, loc, adaptor.getValue(), elementType); convertScalarToDtype(rewriter, loc, adaptor.getValue(), elementType,
std::nullopt, dstOriginalDtype);
Type padType = tensor::PadOp::inferResultType( Type padType = tensor::PadOp::inferResultType(
cast<RankedTensorType>(self.getType()), staticLow, staticHigh); cast<RankedTensorType>(self.getType()), staticLow, staticHigh);
@ -209,26 +213,38 @@ public:
Value one = getConstant(rewriter, loc, 1, indexType); Value one = getConstant(rewriter, loc, 1, indexType);
Value hDimSizeMinusOne = createSub(hDimSize, one); Value hDimSizeMinusOne = createSub(hDimSize, one);
Value vDimSizeMinusOne = createSub(vDimSize, one); Value vDimSizeMinusOne = createSub(vDimSize, one);
SmallVector<Value> allOneStrides(numDims, one); SmallVector<Value> allOneStridesVal(numDims, one);
SmallVector<OpFoldResult> allOneStrides =
getAsOpFoldResult(allOneStridesVal);
SmallVector<Value> extractOffsetsLT(numDims, zero); SmallVector<Value> extractOffsetsLTVal(numDims, zero);
extractOffsetsLT[hDim] = zero; extractOffsetsLTVal[hDim] = zero;
extractOffsetsLT[vDim] = zero; extractOffsetsLTVal[vDim] = zero;
SmallVector<Value> extractShapeLR(numDims, one); SmallVector<OpFoldResult> extractOffsetsLT =
extractShapeLR[hDim] = one; getAsOpFoldResult(extractOffsetsLTVal);
extractShapeLR[vDim] = vDimSize; SmallVector<Value> extractShapeLRVal(numDims, one);
extractShapeLRVal[hDim] = one;
extractShapeLRVal[vDim] = vDimSize;
SmallVector<OpFoldResult> extractShapeLR =
getAsOpFoldResult(extractShapeLRVal);
SmallVector<Value> extractOffsetsRight(numDims, zero); SmallVector<Value> extractOffsetsRightVal(numDims, zero);
extractOffsetsRight[hDim] = hDimSizeMinusOne; extractOffsetsRightVal[hDim] = hDimSizeMinusOne;
extractOffsetsRight[vDim] = zero; extractOffsetsRightVal[vDim] = zero;
SmallVector<OpFoldResult> extractOffsetsRight =
getAsOpFoldResult(extractOffsetsRightVal);
SmallVector<Value> extractOffsetsBottom(numDims, zero); SmallVector<Value> extractOffsetsBottomVal(numDims, zero);
extractOffsetsBottom[hDim] = zero; extractOffsetsBottomVal[hDim] = zero;
extractOffsetsBottom[vDim] = vDimSizeMinusOne; extractOffsetsBottomVal[vDim] = vDimSizeMinusOne;
SmallVector<OpFoldResult> extractOffsetsBottom =
getAsOpFoldResult(extractOffsetsBottomVal);
SmallVector<Value> extractShapeTB(numDims, one); SmallVector<Value> extractShapeTBVal(numDims, one);
extractShapeTB[hDim] = hDimSize; extractShapeTBVal[hDim] = hDimSize;
extractShapeTB[vDim] = one; extractShapeTBVal[vDim] = one;
SmallVector<OpFoldResult> extractShapeTB =
getAsOpFoldResult(extractShapeTBVal);
SmallVector<Value> tensorsLeft; SmallVector<Value> tensorsLeft;
SmallVector<Value> tensorsRight; SmallVector<Value> tensorsRight;
@ -240,24 +256,26 @@ public:
Value vCenterLeftSlice = rewriter.create<tensor::ExtractSliceOp>( Value vCenterLeftSlice = rewriter.create<tensor::ExtractSliceOp>(
loc, input, extractOffsetsLT, extractShapeLR, allOneStrides); loc, input, extractOffsetsLT, extractShapeLR, allOneStrides);
Value vLeftSlice = vCenterLeftSlice; Value vLeftSlice = vCenterLeftSlice;
SmallVector<Value> extractIndices(numDims, zero);
if (hasTopPadding) { if (hasTopPadding) {
Value topLeftValue = rewriter.create<tensor::ExtractOp>( Value topLeftValue =
loc, input, ValueRange{zero, zero, zero, zero}); rewriter.create<tensor::ExtractOp>(loc, input, extractIndices);
// pad vCenterLeftSlice on the top // pad vCenterLeftSlice on the top
SmallVector<int64_t> lowPadding(4, 0); SmallVector<int64_t> lowPadding(numDims, 0);
SmallVector<int64_t> highPadding(4, 0); SmallVector<int64_t> highPadding(numDims, 0);
lowPadding[2] = padInts[2]; lowPadding[vDim] = padInts[2];
vLeftSlice = torch_to_linalg::getPaddedTensor( vLeftSlice = torch_to_linalg::getPaddedTensor(
op, rewriter, vLeftSlice, lowPadding, highPadding, topLeftValue); op, rewriter, vLeftSlice, lowPadding, highPadding, topLeftValue);
} }
if (hasBottomPadding) { if (hasBottomPadding) {
Value bottomLeftValue = rewriter.create<tensor::ExtractOp>( extractIndices[vDim] = vDimSizeMinusOne;
loc, input, ValueRange{zero, zero, vDimSizeMinusOne, zero}); Value bottomLeftValue =
rewriter.create<tensor::ExtractOp>(loc, input, extractIndices);
// pad vLeftSlice at the bottom // pad vLeftSlice at the bottom
SmallVector<int64_t> lowPadding(4, 0); SmallVector<int64_t> lowPadding(numDims, 0);
SmallVector<int64_t> highPadding(4, 0); SmallVector<int64_t> highPadding(numDims, 0);
highPadding[2] = padInts[3]; highPadding[vDim] = padInts[3];
vLeftSlice = torch_to_linalg::getPaddedTensor( vLeftSlice = torch_to_linalg::getPaddedTensor(
op, rewriter, vLeftSlice, lowPadding, highPadding, bottomLeftValue); op, rewriter, vLeftSlice, lowPadding, highPadding, bottomLeftValue);
} }
@ -265,7 +283,7 @@ public:
tensorsLeft.push_back(vLeftSlice); tensorsLeft.push_back(vLeftSlice);
} }
Value leftPadTile = Value leftPadTile =
rewriter.create<tensor::ConcatOp>(loc, 3, tensorsLeft); rewriter.create<tensor::ConcatOp>(loc, hDim, tensorsLeft);
tensorsRes.push_back(leftPadTile); tensorsRes.push_back(leftPadTile);
} }
if (hasTopPadding) { if (hasTopPadding) {
@ -283,33 +301,35 @@ public:
tensorsCenter.push_back(bottomHcenterSlice); tensorsCenter.push_back(bottomHcenterSlice);
} }
} }
centerTile = rewriter.create<tensor::ConcatOp>(loc, 2, tensorsCenter); centerTile = rewriter.create<tensor::ConcatOp>(loc, vDim, tensorsCenter);
tensorsRes.push_back(centerTile); tensorsRes.push_back(centerTile);
if (hasRightPadding) { if (hasRightPadding) {
Value vCenterRightSlice = rewriter.create<tensor::ExtractSliceOp>( Value vCenterRightSlice = rewriter.create<tensor::ExtractSliceOp>(
loc, input, extractOffsetsRight, extractShapeLR, allOneStrides); loc, input, extractOffsetsRight, extractShapeLR, allOneStrides);
Value vRightSlice = vCenterRightSlice; Value vRightSlice = vCenterRightSlice;
SmallVector<Value> extractIndices(numDims, zero);
extractIndices[hDim] = hDimSizeMinusOne;
if (hasTopPadding) { if (hasTopPadding) {
Value topRightValue = rewriter.create<tensor::ExtractOp>( Value topRightValue = rewriter.create<tensor::ExtractOp>(
loc, input, ValueRange{zero, zero, zero, hDimSizeMinusOne}); loc, input, ValueRange{zero, zero, zero, hDimSizeMinusOne});
// pad vCenterRightSlice on the top // pad vCenterRightSlice on the top
SmallVector<int64_t> lowPadding(4, 0); SmallVector<int64_t> lowPadding(numDims, 0);
SmallVector<int64_t> highPadding(4, 0); SmallVector<int64_t> highPadding(numDims, 0);
lowPadding[2] = padInts[2]; lowPadding[vDim] = padInts[2];
vRightSlice = torch_to_linalg::getPaddedTensor( vRightSlice = torch_to_linalg::getPaddedTensor(
op, rewriter, vRightSlice, lowPadding, highPadding, topRightValue); op, rewriter, vRightSlice, lowPadding, highPadding, topRightValue);
} }
if (hasBottomPadding) { if (hasBottomPadding) {
Value bottomRightValue = rewriter.create<tensor::ExtractOp>( extractIndices[vDim] = vDimSizeMinusOne;
loc, input, Value bottomRightValue =
ValueRange{zero, zero, vDimSizeMinusOne, hDimSizeMinusOne}); rewriter.create<tensor::ExtractOp>(loc, input, extractIndices);
// Pad vCenterRightSlice or vRightTopPaddedSlice at the bottom. // Pad vCenterRightSlice or vRightTopPaddedSlice at the bottom.
SmallVector<int64_t> lowPadding(4, 0); SmallVector<int64_t> lowPadding(numDims, 0);
SmallVector<int64_t> highPadding(4, 0); SmallVector<int64_t> highPadding(numDims, 0);
highPadding[2] = padInts[3]; highPadding[vDim] = padInts[3];
vRightSlice = torch_to_linalg::getPaddedTensor( vRightSlice = torch_to_linalg::getPaddedTensor(
op, rewriter, vRightSlice, lowPadding, highPadding, op, rewriter, vRightSlice, lowPadding, highPadding,
bottomRightValue); bottomRightValue);
@ -318,10 +338,10 @@ public:
tensorsRight.push_back(vRightSlice); tensorsRight.push_back(vRightSlice);
} }
Value rightPadTile = Value rightPadTile =
rewriter.create<tensor::ConcatOp>(loc, 3, tensorsRight); rewriter.create<tensor::ConcatOp>(loc, hDim, tensorsRight);
tensorsRes.push_back(rightPadTile); tensorsRes.push_back(rightPadTile);
} }
Value resTensor = rewriter.create<tensor::ConcatOp>(loc, 3, tensorsRes); Value resTensor = rewriter.create<tensor::ConcatOp>(loc, hDim, tensorsRes);
Type newResultType = getTypeConverter()->convertType(op.getType()); Type newResultType = getTypeConverter()->convertType(op.getType());
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, resTensor); rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, resTensor);
return success(); return success();

View File

@ -6379,7 +6379,11 @@ class DecomposeAtenPadOp : public OpRewritePattern<AtenPadOp> {
using OpRewritePattern::OpRewritePattern; using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(AtenPadOp op, LogicalResult matchAndRewrite(AtenPadOp op,
PatternRewriter &rewriter) const override { PatternRewriter &rewriter) const override {
std::string mode;
if (!matchPattern(op.getMode(), m_TorchConstantStr(mode)))
return rewriter.notifyMatchFailure(op, "mode must be a constant string");
if (mode == "constant") {
Value value = op.getValue(); Value value = op.getValue();
if (isa<Torch::OptionalType>(value.getType())) if (isa<Torch::OptionalType>(value.getType()))
return rewriter.notifyMatchFailure(op, "optional type not supported"); return rewriter.notifyMatchFailure(op, "optional type not supported");
@ -6391,6 +6395,76 @@ class DecomposeAtenPadOp : public OpRewritePattern<AtenPadOp> {
op, op.getType(), op.getSelf(), op.getPad(), value); op, op.getType(), op.getSelf(), op.getPad(), value);
return success(); return success();
} }
SmallVector<Value> padValues;
if (!getListConstructElements(op.getPad(), padValues))
return failure();
SmallVector<int64_t> padInts;
Value usefulPads = op.getPad();
uint64_t usefulPadIndexEnd = padValues.size();
// try to reduce the number of padding dims if possible
if (matchPattern(op.getPad(), m_TorchListOfConstantInts(padInts))) {
if ((padInts.size() % 2) == 1)
return rewriter.notifyMatchFailure(op,
"expected an even number of pads");
for (uint64_t i = padInts.size() - 1; i > 0; i -= 2) {
if (padInts[i] != 0 || padInts[i - 1] != 0)
break;
usefulPadIndexEnd = i - 1;
}
if (usefulPadIndexEnd == 0) {
rewriter.replaceOp(op, op.getSelf());
return success();
}
}
// we don't have support for 1-D replicate pad, so pass it as 2d if
// possible.
// TODO: add support for AtenReplicatePad1dOp and remove this.
if (mode == "replicate" && usefulPadIndexEnd == 2 && padValues.size() >= 4)
usefulPadIndexEnd = 4;
// make a new list of padding ints if dimensionality reduction can be
// performed
if (usefulPadIndexEnd < padValues.size()) {
ArrayRef<Value> usefulPadValues(padValues.begin(),
padValues.begin() + usefulPadIndexEnd);
usefulPads = rewriter.create<PrimListConstructOp>(
op.getLoc(),
rewriter.getType<Torch::ListType>(rewriter.getType<Torch::IntType>()),
usefulPadValues);
}
uint64_t numPadDims = usefulPadIndexEnd / 2;
if (mode == "reflect") {
// only support for relectionpad 1d and 2d
if (numPadDims == 2) {
rewriter.replaceOpWithNewOp<AtenReflectionPad2dOp>(
op, op.getType(), op.getSelf(), usefulPads);
return success();
}
if (numPadDims == 1) {
rewriter.replaceOpWithNewOp<AtenReflectionPad1dOp>(
op, op.getType(), op.getSelf(), usefulPads);
return success();
}
return failure();
}
if (mode == "replicate") {
// only support for replication pad 2d
if (numPadDims != 2)
return failure();
rewriter.replaceOpWithNewOp<AtenReplicationPad2dOp>(
op, op.getType(), op.getSelf(), usefulPads);
return success();
}
return rewriter.notifyMatchFailure(op, "unsupported mode: " + mode);
}
}; };
} // namespace } // namespace

View File

@ -40,7 +40,8 @@ bool isQCommutingOp(mlir::Operation *op) {
// if adding a new commuting op here, be sure to add a // if adding a new commuting op here, be sure to add a
// RemoveUnused pattern for that op to clean up afterwards // RemoveUnused pattern for that op to clean up afterwards
return llvm::isa<AtenTransposeIntOp, AtenReshapeOp, AtenSliceTensorOp, return llvm::isa<AtenTransposeIntOp, AtenReshapeOp, AtenSliceTensorOp,
PrimsCollapseOp, AtenViewOp>(op); PrimsCollapseOp, AtenViewOp, AtenPadOp, AtenConstantPadNdOp>(
op);
} }
// The following conversion takes patterns of the form [op0 -> MPTQT -> dequant // The following conversion takes patterns of the form [op0 -> MPTQT -> dequant
@ -65,7 +66,7 @@ public:
for (unsigned i : QuantInfo<SrcOp>::operandsToQuantize) { for (unsigned i : QuantInfo<SrcOp>::operandsToQuantize) {
Value operand = operands[i]; Value operand = operands[i];
std::stack<mlir::Operation *> commutingOpStack; std::stack<mlir::Operation *> commutingOpStack;
Value dequantOpd, MPTQTOpd; Value dequantOpd, MPTQTOpd, scale, zeroPoint;
for (unsigned k = 0; k < depth + 1; k++) { for (unsigned k = 0; k < depth + 1; k++) {
auto currOp = operand.getDefiningOp(); auto currOp = operand.getDefiningOp();
// Case 0 : currOp is a nullptr (e.g., operand is a block argument) // Case 0 : currOp is a nullptr (e.g., operand is a block argument)
@ -84,6 +85,8 @@ public:
auto MPTQTOp = auto MPTQTOp =
dequantOpd.getDefiningOp<Aten_MakePerTensorQuantizedTensorOp>(); dequantOpd.getDefiningOp<Aten_MakePerTensorQuantizedTensorOp>();
MPTQTOpd = MPTQTOp.getOperand(0); MPTQTOpd = MPTQTOp.getOperand(0);
scale = MPTQTOp.getOperand(1);
zeroPoint = MPTQTOp.getOperand(2);
} }
// either a dequant was found or chain broken, so break loop // either a dequant was found or chain broken, so break loop
break; break;
@ -107,6 +110,47 @@ public:
commutingOpStack.pop(); commutingOpStack.pop();
llvm::SmallVector<Value> currOperands(currOp->getOperands()); llvm::SmallVector<Value> currOperands(currOp->getOperands());
currOperands[0] = oldOpd; currOperands[0] = oldOpd;
// pad ops aren't quite commuting, so we include some extra logic to
// quantize the padding value
if (isa<Torch::AtenPadOp, Torch::AtenConstantPadNdOp>(currOp)) {
Value floatPadValue = currOperands.back();
Value quantPadValue;
if (isa<Torch::NoneType>(floatPadValue.getType()))
quantPadValue = rewriter.create<AtenFloatScalarOp>(loc, zeroPoint);
else {
floatPadValue =
rewriter.create<AtenFloatScalarOp>(loc, floatPadValue);
quantPadValue = rewriter.create<Torch::AtenDivFloatOp>(
loc, floatPadValue, scale);
quantPadValue = rewriter.create<Torch::AtenAddFloatIntOp>(
loc, quantPadValue, zeroPoint);
}
// clamp pad value to qint range
if (auto intType = dyn_cast<mlir::IntegerType>(intDType)) {
bool isSigned = intType.isSignedInteger();
int64_t width = intType.getWidth();
assert(width < 64 &&
"quantized int bitwidth should be less than 64");
int64_t minInt = isSigned ? -(1 << (width - 1)) : 0;
int64_t maxInt = isSigned ? -minInt - 1 : ((1 << width) - 1);
Value minQValueFloat = rewriter.create<ConstantFloatOp>(
loc, rewriter.getF64FloatAttr(minInt));
Value maxQValueFloat = rewriter.create<ConstantFloatOp>(
loc, rewriter.getF64FloatAttr(maxInt));
SmallVector<int64_t> emptyShape;
auto floatTensorType = rewriter.getType<Torch::ValueTensorType>(
emptyShape, rewriter.getF64Type());
Value quantPadValueTensor = createRank0Tensor(
rewriter, loc, floatTensorType, quantPadValue);
Value clampedTensor = rewriter.create<Torch::AtenClampOp>(
loc, floatTensorType, quantPadValueTensor, minQValueFloat,
maxQValueFloat);
quantPadValue = rewriter.create<Torch::AtenItemOp>(
loc, rewriter.getType<Torch::FloatType>(), clampedTensor);
}
// quantPadValue is a float, but will get converted/truncated
currOperands.back() = quantPadValue;
}
// get new result type // get new result type
auto oldType = cast<ValueTensorType>(currOp->getResultTypes()[0]); auto oldType = cast<ValueTensorType>(currOp->getResultTypes()[0]);
auto intType = auto intType =
@ -374,7 +418,8 @@ public:
RemoveUnused<Aten_MakePerTensorQuantizedTensorOp>, RemoveUnused<Aten_MakePerTensorQuantizedTensorOp>,
RemoveUnused<AtenTransposeIntOp>, RemoveUnused<AtenSliceTensorOp>, RemoveUnused<AtenTransposeIntOp>, RemoveUnused<AtenSliceTensorOp>,
RemoveUnused<AtenReshapeOp>, RemoveUnused<PrimsCollapseOp>, RemoveUnused<AtenReshapeOp>, RemoveUnused<PrimsCollapseOp>,
RemoveUnused<AtenViewOp>, RemoveUnused<AtenViewOp>, RemoveUnused<AtenPadOp>,
RemoveUnused<AtenConstantPadNdOp>,
QuantizeOperandsPastCommutingOps<AtenConvolutionOp, 5>, QuantizeOperandsPastCommutingOps<AtenConvolutionOp, 5>,
QuantizeOperandsPastCommutingOps<AtenReluOp, 0>, QuantizeOperandsPastCommutingOps<AtenReluOp, 0>,
QuantizeOperandsPastCommutingOps<AtenMatmulOp, 2>, QuantizeOperandsPastCommutingOps<AtenMatmulOp, 2>,

View File

@ -1014,6 +1014,38 @@ func.func @test_conv_with_strides_padding(%arg0: !torch.vtensor<[1,1,7,5],f32>,
// ----- // -----
// CHECK-LABEL: @test_conv_with_asymmetric_padding
func.func @test_conv_with_asymmetric_padding(%arg0: !torch.vtensor<[1,1,7,5],f32>, %arg1: !torch.vtensor<[1,1,3,3],f32>) -> !torch.vtensor<[1,1,4,3],f32> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: %[[int0:.*]] = torch.constant.int 0
// CHECK: %[[int2:.*]] = torch.constant.int 2
// CHECK: %[[int0_0:.*]] = torch.constant.int 0
// CHECK: %[[int2_1:.*]] = torch.constant.int 2
// CHECK: %[[int0_2:.*]] = torch.constant.int 0
// CHECK: %[[int0_3:.*]] = torch.constant.int 0
// CHECK: %[[FakePADS:.*]] = torch.prim.ListConstruct %[[int0_0]], %[[int0_3]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[OGPADS:.*]] = torch.prim.ListConstruct %[[int0]], %[[int2]], %[[int2_1]], %[[int0_2]] : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[str:.*]] = torch.constant.str "constant"
// CHECK: %[[float0:.*]] = torch.constant.float 0.000
// CHECK: %[[PrePad:.*]] = torch.aten.pad %arg0, %[[OGPADS]], %[[str]], %[[float0]] : !torch.vtensor<[1,1,7,5],f32>, !torch.list<int>, !torch.str, !torch.float -> !torch.vtensor<[1,1,9,7],f32>
// CHECK: %[[C1_1:.*]] = torch.constant.int 1
// CHECK: %[[C1_2:.*]] = torch.constant.int 1
// CHECK: %[[C2:.*]] = torch.constant.int 2
// CHECK: %[[C2_0:.*]] = torch.constant.int 2
// CHECK: %[[C0:.*]] = torch.constant.int 0
// CHECK: %[[DILATIONS:.*]] = torch.prim.ListConstruct %[[C1_1]], %[[C1_2]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[STRIDE:.*]] = torch.prim.ListConstruct %[[C2]], %[[C2_0]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[OUTPUT_PADDING:.*]] = torch.prim.ListConstruct %[[C0]], %[[C0]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[TRANSPOSED:.*]] = torch.constant.bool false
// CHECK: %[[BIAS:.*]] = torch.constant.none
// CHECK: %[[GROUPS:.*]] = torch.constant.int 1
// CHECK: %[[Conv:.*]] = torch.aten.convolution %[[PrePad]], %arg1, %[[BIAS]], %[[STRIDE]], %[[FakePADS]], %[[DILATIONS]], %[[TRANSPOSED]], %[[OUTPUT_PADDING]], %[[GROUPS]] : !torch.vtensor<[1,1,9,7],f32>, !torch.vtensor<[1,1,3,3],f32>, !torch.none, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int -> !torch.vtensor<[1,1,4,3],f32>
// CHECK: return %[[Conv]]
%0 = torch.operator "onnx.Conv"(%arg0, %arg1) {torch.onnx.kernel_shape = [3 : si64, 3 : si64], torch.onnx.pads = [2 : si64, 0 : si64, 0 : si64, 2 : si64], torch.onnx.strides = [2 : si64, 2 : si64]} : (!torch.vtensor<[1,1,7,5],f32>, !torch.vtensor<[1,1,3,3],f32>) -> !torch.vtensor<[1,1,4,3],f32>
return %0 : !torch.vtensor<[1,1,4,3],f32>
}
// -----
// CHECK-LABEL: @test_conv_with_bias_strides_padding // CHECK-LABEL: @test_conv_with_bias_strides_padding
func.func @test_conv_with_bias_strides_padding(%arg0: !torch.vtensor<[?,?,224,224],f32>, %arg1: !torch.vtensor<[64,3,7,7],f32>, %arg2: !torch.vtensor<[64],f32>) -> !torch.vtensor<[?,64,112,112],f32> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { func.func @test_conv_with_bias_strides_padding(%arg0: !torch.vtensor<[?,?,224,224],f32>, %arg1: !torch.vtensor<[64,3,7,7],f32>, %arg2: !torch.vtensor<[64],f32>) -> !torch.vtensor<[?,64,112,112],f32> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: %[[C3:.*]] = torch.constant.int 3 // CHECK: %[[C3:.*]] = torch.constant.int 3

View File

@ -815,32 +815,40 @@ func.func @test_grid_sampler02(%arg0: !torch.vtensor<[5,10,10,4],f32>, %arg1: !t
// ----- // -----
// CHECK-LABEL: @test_grid_sampler03 // CHECK-LABEL: func.func @test_oldest_pad
// CHECK: %[[INT1:.*]] = torch.constant.int 1 func.func @test_oldest_pad(%arg0: !torch.vtensor<[3,4],f32>) -> !torch.vtensor<[5,4],f32> attributes {torch.onnx_meta.opset_version = 1 : si64} {
// CHECK: %[[INT0:.*]] = torch.constant.int 0 // CHECK: %[[int0:.*]] = torch.constant.int 0
// CHECK: %[[B0:.*]] = torch.constant.bool true // CHECK: %[[int0_0:.*]] = torch.constant.int 0
// CHECK: %[[A0:.*]] = torch.aten.grid_sampler %arg0, %arg1, %[[INT1]], %[[INT0]], %[[B0]] : !torch.vtensor<[5,10,10,4],f32> // CHECK: %[[int2:.*]] = torch.constant.int 2
func.func @test_grid_sampler03(%arg0: !torch.vtensor<[5,10,10,4],f32>, %arg1: !torch.vtensor<[5,7,8,2],f32>) -> !torch.vtensor<[?,?,?,?],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 17 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: %[[int0_1:.*]] = torch.constant.int 0
%0 = torch.operator "onnx.GridSample" (%arg0, %arg1) {torch.onnx.align_corners = 1 : si64, torch.onnx.mode = "nearest", torch.onnx.padding_mode = "zeros"} : (!torch.vtensor<[5,10,10,4],f32>, !torch.vtensor<[5,7,8,2],f32>) -> !torch.vtensor<[?,?,?,?],f32> // CHECK: %[[float0:.*]] = torch.constant.float 0.000000e+00
return %0 : !torch.vtensor<[?,?,?,?],f32> // CHECK: %[[list:.*]] = torch.prim.ListConstruct %[[int0_0]], %[[int0_1]], %[[int0]], %[[int2]] : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[pad:.*]] = torch.aten.constant_pad_nd %arg0, %[[list]], %[[float0]] : !torch.vtensor<[3,4],f32>, !torch.list<int>, !torch.float -> !torch.vtensor<[5,4],f32>
// CHECK: return %[[pad]] : !torch.vtensor<[5,4],f32>
%0 = torch.operator "onnx.Pad"(%arg0) {torch.onnx.mode = "constant", torch.onnx.paddings = [0 : si64, 0 : si64, 2 : si64, 0 : si64], torch.onnx.value = 0.000000e+00 : f32} : (!torch.vtensor<[3,4],f32>) -> !torch.vtensor<[5,4],f32>
return %0 : !torch.vtensor<[5,4],f32>
}
// -----
// CHECK-LABEL: func.func @test_less_or_equal // CHECK-LABEL: func.func @test_old_pad
func.func @test_less_or_equal(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],i1> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 16 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { func.func @test_old_pad(%arg0: !torch.vtensor<[3,4],f32>) -> !torch.vtensor<[5,4],f32> attributes {torch.onnx_meta.opset_version = 11 : si64} {
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: !torch.vtensor<[3,4,5],f32> // CHECK: %[[int0:.*]] = torch.constant.int 0
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: !torch.vtensor<[3,4,5],f32> // CHECK: %[[int0_0:.*]] = torch.constant.int 0
// CHECK: torch.aten.le.Tensor %[[ARG0]], %[[ARG1]] : !torch.vtensor<[3,4,5],f32>, !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],i1> // CHECK: %[[int2:.*]] = torch.constant.int 2
%0 = torch.operator "onnx.LessOrEqual"(%arg0, %arg1) : (!torch.vtensor<[3,4,5],f32>, !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],i1> // CHECK: %[[int0_1:.*]] = torch.constant.int 0
return %0 : !torch.vtensor<[3,4,5],i1> // CHECK: %[[float0:.*]] = torch.constant.float 0.000000e+00
// CHECK: %[[list:.*]] = torch.prim.ListConstruct %[[int0_0]], %[[int0_1]], %[[int0]], %[[int2]] : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[pad:.*]] = torch.aten.constant_pad_nd %arg0, %[[list]], %[[float0]] : !torch.vtensor<[3,4],f32>, !torch.list<int>, !torch.float -> !torch.vtensor<[5,4],f32>
// CHECK: return %[[pad]] : !torch.vtensor<[5,4],f32>
%0 = torch.operator "onnx.Pad"(%arg0) {torch.onnx.mode = "constant", torch.onnx.pads = [0 : si64, 0 : si64, 2 : si64, 0 : si64], torch.onnx.value = 0.000000e+00 : f32} : (!torch.vtensor<[3,4],f32>) -> !torch.vtensor<[5,4],f32>
return %0 : !torch.vtensor<[5,4],f32>
}
// -----
// CHECK-LABEL: func.func @test_pad // CHECK-LABEL: func.func @test_pad
func.func @test_pad(%arg0: !torch.vtensor<[3,4],f32>, %arg1: !torch.vtensor<[4], si64>, %arg2: !torch.vtensor<[], f32>) -> !torch.vtensor<[5,4],f32> attributes {torch.onnx_meta.opset_version = 19 : si64} { func.func @test_pad(%arg0: !torch.vtensor<[3,4],f32>, %arg1: !torch.vtensor<[4], si64>, %arg2: !torch.vtensor<[], f32>) -> !torch.vtensor<[5,4],f32> attributes {torch.onnx_meta.opset_version = 19 : si64} {
// CHECK: %[[VAL:.+]] = torch.aten.item %arg2 : !torch.vtensor<[],f32> -> !torch.float
// CHECK: %[[INT0:.+]] = torch.constant.int 0 // CHECK: %[[INT0:.+]] = torch.constant.int 0
// CHECK: %[[INT0_0:.+]] = torch.constant.int 0 // CHECK: %[[INT0_0:.+]] = torch.constant.int 0
// CHECK: %[[SELECT_0:.+]] = torch.aten.select.int %arg1, %[[INT0]], %[[INT0_0]] : !torch.vtensor<[4],si64>, !torch.int, !torch.int -> !torch.vtensor<[],si64> // CHECK: %[[SELECT_0:.+]] = torch.aten.select.int %arg1, %[[INT0]], %[[INT0_0]] : !torch.vtensor<[4],si64>, !torch.int, !torch.int -> !torch.vtensor<[],si64>
@ -854,9 +862,9 @@ func.func @test_pad(%arg0: !torch.vtensor<[3,4],f32>, %arg1: !torch.vtensor<[4],
// CHECK: %[[INT3:.+]] = torch.constant.int 3 // CHECK: %[[INT3:.+]] = torch.constant.int 3
// CHECK: %[[SELECT_3:.+]] = torch.aten.select.int %arg1, %[[INT0]], %[[INT3]] : !torch.vtensor<[4],si64>, !torch.int, !torch.int -> !torch.vtensor<[],si64> // CHECK: %[[SELECT_3:.+]] = torch.aten.select.int %arg1, %[[INT0]], %[[INT3]] : !torch.vtensor<[4],si64>, !torch.int, !torch.int -> !torch.vtensor<[],si64>
// CHECK: %[[ITEM_3:.+]] = torch.aten.item %[[SELECT_3]] : !torch.vtensor<[],si64> -> !torch.int // CHECK: %[[ITEM_3:.+]] = torch.aten.item %[[SELECT_3]] : !torch.vtensor<[],si64> -> !torch.int
// CHECK: %[[LIST:.+]] = torch.prim.ListConstruct %[[ITEM_1]], %[[ITEM_3]], %[[ITEM_0]], %[[ITEM_2]] : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[PAD:.+]] = torch.aten.constant_pad_nd %arg0, %[[LIST]], %[[VAL]] : !torch.vtensor<[3,4],f32>, !torch.list<int>, !torch.float -> !torch.vtensor<[5,4],f32>
// CHECK: return %[[PAD]] : !torch.vtensor<[5,4],f32> // CHECK: return %[[PAD]] : !torch.vtensor<[5,4],f32>
%0 = torch.operator "onnx.Pad"(%arg0, %arg1, %arg2) {torch.onnx.mode = "constant"} : (!torch.vtensor<[3,4],f32>, !torch.vtensor<[4], si64>, !torch.vtensor<[], f32>) -> !torch.vtensor<[5,4],f32> %0 = torch.operator "onnx.Pad"(%arg0, %arg1, %arg2) {torch.onnx.mode = "constant"} : (!torch.vtensor<[3,4],f32>, !torch.vtensor<[4], si64>, !torch.vtensor<[], f32>) -> !torch.vtensor<[5,4],f32>
return %0 : !torch.vtensor<[5,4],f32> return %0 : !torch.vtensor<[5,4],f32>
@ -864,12 +872,36 @@ func.func @test_pad(%arg0: !torch.vtensor<[3,4],f32>, %arg1: !torch.vtensor<[4],
// -----
// Regression test for onnx.Pad on an integer element type: the scalar pad
// value read from the si32 tensor %arg2 must be materialized as a !torch.int
// (not a !torch.float), and the ONNX pads [x1_begin, x2_begin, x1_end, x2_end]
// are reordered into the last-dimension-first [left, right, top, bottom] list
// that torch.aten.constant_pad_nd expects (ITEM_1, ITEM_3, ITEM_0, ITEM_2).
// CHECK-LABEL: func.func @test_i32pad
func.func @test_i32pad(%arg0: !torch.vtensor<[3,4],si32>, %arg1: !torch.vtensor<[4], si64>, %arg2: !torch.vtensor<[], si32>) -> !torch.vtensor<[5,4],si32> attributes {torch.onnx_meta.opset_version = 19 : si64} {
// CHECK: %[[INT0:.+]] = torch.constant.int 0
// CHECK: %[[INT0_0:.+]] = torch.constant.int 0
// CHECK: %[[SELECT_0:.+]] = torch.aten.select.int %arg1, %[[INT0]], %[[INT0_0]] : !torch.vtensor<[4],si64>, !torch.int, !torch.int -> !torch.vtensor<[],si64>
// CHECK: %[[ITEM_0:.+]] = torch.aten.item %[[SELECT_0]] : !torch.vtensor<[],si64> -> !torch.int
// CHECK: %[[INT1:.+]] = torch.constant.int 1
// CHECK: %[[SELECT_1:.+]] = torch.aten.select.int %arg1, %[[INT0]], %[[INT1]] : !torch.vtensor<[4],si64>, !torch.int, !torch.int -> !torch.vtensor<[],si64>
// CHECK: %[[ITEM_1:.+]] = torch.aten.item %[[SELECT_1]] : !torch.vtensor<[],si64> -> !torch.int
// CHECK: %[[INT2:.+]] = torch.constant.int 2
// CHECK: %[[SELECT_2:.+]] = torch.aten.select.int %arg1, %[[INT0]], %[[INT2]] : !torch.vtensor<[4],si64>, !torch.int, !torch.int -> !torch.vtensor<[],si64>
// CHECK: %[[ITEM_2:.+]] = torch.aten.item %[[SELECT_2]] : !torch.vtensor<[],si64> -> !torch.int
// CHECK: %[[INT3:.+]] = torch.constant.int 3
// CHECK: %[[SELECT_3:.+]] = torch.aten.select.int %arg1, %[[INT0]], %[[INT3]] : !torch.vtensor<[4],si64>, !torch.int, !torch.int -> !torch.vtensor<[],si64>
// CHECK: %[[ITEM_3:.+]] = torch.aten.item %[[SELECT_3]] : !torch.vtensor<[],si64> -> !torch.int
// CHECK: %[[VAL:.+]] = torch.aten.item %arg2 : !torch.vtensor<[],si32> -> !torch.int
// CHECK: %[[LIST:.+]] = torch.prim.ListConstruct %[[ITEM_1]], %[[ITEM_3]], %[[ITEM_0]], %[[ITEM_2]] : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[PAD:.+]] = torch.aten.constant_pad_nd %arg0, %[[LIST]], %[[VAL]] : !torch.vtensor<[3,4],si32>, !torch.list<int>, !torch.int -> !torch.vtensor<[5,4],si32>
// CHECK: return %[[PAD]] : !torch.vtensor<[5,4],si32>
%0 = torch.operator "onnx.Pad"(%arg0, %arg1, %arg2) {torch.onnx.mode = "constant"} : (!torch.vtensor<[3,4],si32>, !torch.vtensor<[4], si64>, !torch.vtensor<[], si32>) -> !torch.vtensor<[5,4],si32>
return %0 : !torch.vtensor<[5,4],si32>
}
// -----
// CHECK-LABEL: @test_pad_optional_constant
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[3,4],f32>
// CHECK-SAME: %[[ARG1:.*]]: !torch.vtensor<[4],si64>
// CHECK: %[[VAL:.+]] = torch.constant.float 0
// CHECK: torch.aten.constant_pad_nd %[[ARG0]], %{{.*}}, %[[VAL]] : !torch.vtensor<[3,4],f32>, !torch.list<int>, !torch.float -> !torch.vtensor<[5,4],f32>
func.func @test_pad_optional_constant(%arg0: !torch.vtensor<[3,4],f32>, %arg1: !torch.vtensor<[4], si64>) -> !torch.vtensor<[5,4],f32> attributes {torch.onnx_meta.opset_version = 19 : si64} {
%0 = torch.operator "onnx.Pad"(%arg0, %arg1) {torch.onnx.mode = "constant"} : (!torch.vtensor<[3,4],f32>, !torch.vtensor<[4], si64>) -> !torch.vtensor<[5,4],f32>
@ -878,6 +910,34 @@ func.func @test_pad_optional_constant(%arg0: !torch.vtensor<[3,4],f32>, %arg1: !
// -----
// Verifies that onnx.Pad with mode = "wrap" lowers to torch.aten.pad using
// torch's equivalent "circular" mode string; since no constant-value operand
// is supplied, the pad value passed through is the none constant.
// CHECK-LABEL: @test_pad_wrap
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[3,4],f32>
// CHECK-SAME: %[[ARG1:.*]]: !torch.vtensor<[4],si64>
// CHECK: %[[VAL:.+]] = torch.constant.none
// CHECK: %[[STR:.+]] = torch.constant.str "circular"
// CHECK: torch.aten.pad %[[ARG0]], %{{.*}}, %[[STR]], %[[VAL]] : !torch.vtensor<[3,4],f32>, !torch.list<int>, !torch.str, !torch.none -> !torch.vtensor<[5,4],f32>
func.func @test_pad_wrap(%arg0: !torch.vtensor<[3,4],f32>, %arg1: !torch.vtensor<[4], si64>) -> !torch.vtensor<[5,4],f32> attributes {torch.onnx_meta.opset_version = 19 : si64} {
%0 = torch.operator "onnx.Pad"(%arg0, %arg1) {torch.onnx.mode = "wrap"} : (!torch.vtensor<[3,4],f32>, !torch.vtensor<[4], si64>) -> !torch.vtensor<[5,4],f32>
return %0 : !torch.vtensor<[5,4],f32>
}
// -----
// Verifies that onnx.Pad with mode = "edge" lowers to torch.aten.pad using
// torch's equivalent "replicate" mode string; since no constant-value operand
// is supplied, the pad value passed through is the none constant.
// CHECK-LABEL: @test_pad_edge
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[3,4],f32>
// CHECK-SAME: %[[ARG1:.*]]: !torch.vtensor<[4],si64>
// CHECK: %[[VAL:.+]] = torch.constant.none
// CHECK: %[[STR:.+]] = torch.constant.str "replicate"
// CHECK: torch.aten.pad %[[ARG0]], %{{.*}}, %[[STR]], %[[VAL]] : !torch.vtensor<[3,4],f32>, !torch.list<int>, !torch.str, !torch.none -> !torch.vtensor<[5,4],f32>
func.func @test_pad_edge(%arg0: !torch.vtensor<[3,4],f32>, %arg1: !torch.vtensor<[4], si64>) -> !torch.vtensor<[5,4],f32> attributes {torch.onnx_meta.opset_version = 19 : si64} {
%0 = torch.operator "onnx.Pad"(%arg0, %arg1) {torch.onnx.mode = "edge"} : (!torch.vtensor<[3,4],f32>, !torch.vtensor<[4], si64>) -> !torch.vtensor<[5,4],f32>
return %0 : !torch.vtensor<[5,4],f32>
}
// -----
// CHECK-LABEL: func.func @test_pow // CHECK-LABEL: func.func @test_pow
func.func @test_pow(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 15 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { func.func @test_pow(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 15 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: torch.aten.pow.Tensor_Tensor %arg0, %arg1 : !torch.vtensor<[3,4,5],f32>, !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],f32> // CHECK: torch.aten.pow.Tensor_Tensor %arg0, %arg1 : !torch.vtensor<[3,4,5],f32>, !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],f32>

View File

@ -82,6 +82,48 @@ func.func @matmul_commuting(%arg0: !torch.vtensor<[2,128,32,32],si8>) -> !torch.
// -----
// Checks that quantization commutes through a constant-mode torch.aten.pad:
// the dequantize -> pad -> mm chain is rewritten so the pad runs directly on
// the raw si8 tensor. The float pad value 3.5 is requantized with the input's
// scale 0.5 and zero point 1 (3.5 / 0.5 + 1 = 8.0, the padVal constant),
// clamped to the signed-8-bit range [-128, 127], and the padded result is
// re-wrapped with _make_per_tensor_quantized_tensor so the mm executes on
// qint8 operands producing a qint32 accumulator.
// CHECK-LABEL: func.func @mm_pad_commute
func.func @mm_pad_commute(%arg0: !torch.vtensor<[8,8],si8>, %arg1: !torch.vtensor<[11,4],si8>) -> !torch.vtensor<[9,4],f32> {
// CHECK-DAG: %[[cstQuart:.*]] = torch.constant.float 2.500000e-01
// CHECK-DAG: %[[int7:.*]] = torch.constant.int 7
// CHECK-DAG: %[[none:.*]] = torch.constant.none
// CHECK-DAG: %[[qMax:.*]] = torch.constant.float 1.270000e+02
// CHECK-DAG: %[[qMin:.*]] = torch.constant.float -1.280000e+02
// CHECK-DAG: %[[padVal:.*]] = torch.constant.float 8.000000e+00
// CHECK-DAG: %[[str:.*]] = torch.constant.str "constant"
// CHECK-DAG: %[[cstHalf:.*]] = torch.constant.float 5.000000e-01
// CHECK-DAG: %[[int0:.*]] = torch.constant.int 0
// CHECK-DAG: %[[int1:.*]] = torch.constant.int 1
// CHECK-DAG: %[[int2:.*]] = torch.constant.int 2
// CHECK: %[[PadList:.*]] = torch.prim.ListConstruct %[[int1]], %[[int2]], %[[int0]], %[[int1]] : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[EmptyList:.*]] = torch.prim.ListConstruct : () -> !torch.list<int>
// CHECK: %[[Rank0:.*]] = torch.aten.full %[[EmptyList]], %[[padVal]], %[[int7]], %[[none]], %[[none]], %[[none]] : !torch.list<int>, !torch.float, !torch.int, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[],f64>
// CHECK: %[[Clamp:.*]] = torch.aten.clamp %[[Rank0]], %[[qMin]], %[[qMax]] : !torch.vtensor<[],f64>, !torch.float, !torch.float -> !torch.vtensor<[],f64>
// CHECK: %[[Item:.*]] = torch.aten.item %[[Clamp]] : !torch.vtensor<[],f64> -> !torch.float
// CHECK: %[[NewPad:.*]] = torch.aten.pad %arg0, %[[PadList]], %[[str]], %[[Item]] : !torch.vtensor<[8,8],si8>, !torch.list<int>, !torch.str, !torch.float -> !torch.vtensor<[9,11],si8>
// CHECK: %[[NewMPTQT:.*]] = torch.aten._make_per_tensor_quantized_tensor %[[NewPad]], %[[cstHalf]], %[[int1]] : !torch.vtensor<[9,11],si8>, !torch.float, !torch.int -> !torch.vtensor<[9,11],!torch.qint8>
// CHECK: %[[OtherMPTQT:.*]] = torch.aten._make_per_tensor_quantized_tensor %arg1, %[[cstHalf]], %[[int0]] : !torch.vtensor<[11,4],si8>, !torch.float, !torch.int -> !torch.vtensor<[11,4],!torch.qint8>
// CHECK: %[[MM:.*]] = torch.aten.mm %[[NewMPTQT]], %[[OtherMPTQT]] : !torch.vtensor<[9,11],!torch.qint8>, !torch.vtensor<[11,4],!torch.qint8> -> !torch.vtensor<[9,4],!torch.qint32>
%scale = torch.constant.float 0.5
%false = torch.constant.bool false
%zero = torch.constant.int 0
%one = torch.constant.int 1
%two = torch.constant.int 2
// Float-domain pad value that the rewrite must requantize (see header note).
%floatpad = torch.constant.float 3.5
%zp = torch.constant.int -128
%6 = torch.aten._make_per_tensor_quantized_tensor %arg0, %scale, %one : !torch.vtensor<[8,8],si8>, !torch.float, !torch.int -> !torch.vtensor<[8,8],!torch.qint8>
%7 = torch.aten.dequantize.tensor %6 : !torch.vtensor<[8,8],!torch.qint8> -> !torch.vtensor<[8,8],f32>
%list = torch.prim.ListConstruct %one, %two, %zero, %one : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
%str = torch.constant.str "constant"
// Pad of the dequantized tensor: the pattern under test folds this back onto
// the quantized input instead.
%pad = torch.aten.pad %7, %list, %str, %floatpad : !torch.vtensor<[8,8],f32>, !torch.list<int>, !torch.str, !torch.float -> !torch.vtensor<[9,11],f32>
%12 = torch.aten._make_per_tensor_quantized_tensor %arg1, %scale, %zero : !torch.vtensor<[11,4],si8>, !torch.float, !torch.int -> !torch.vtensor<[11,4],!torch.qint8>
%13 = torch.aten.dequantize.tensor %12 : !torch.vtensor<[11,4],!torch.qint8> -> !torch.vtensor<[11,4],f32>
%16 = torch.aten.mm %pad, %13 : !torch.vtensor<[9,11],f32>, !torch.vtensor<[11,4],f32> -> !torch.vtensor<[9,4],f32>
return %16 : !torch.vtensor<[9,4],f32>
}
// -----
// CHECK-LABEL: @convolution_bias // CHECK-LABEL: @convolution_bias
func.func @convolution_bias(%arg0: !torch.vtensor<[1,3,8,8],si8>, %arg1: !torch.vtensor<[3,3,2,2],si8>, %arg2 : !torch.vtensor<[3], f32>) -> !torch.vtensor<[1,3,7,7],f32> { func.func @convolution_bias(%arg0: !torch.vtensor<[1,3,8,8],si8>, %arg1: !torch.vtensor<[3,3,2,2],si8>, %arg2 : !torch.vtensor<[3], f32>) -> !torch.vtensor<[1,3,7,7],f32> {
%scale = torch.constant.float 0.5 %scale = torch.constant.float 0.5