Adds misc fixes for some padding-related issues (#3528)

This patch adds a few miscellaneous pad-op-related changes:

1. Addresses issue <https://github.com/llvm/torch-mlir/issues/3457>
2. Addresses issue <https://github.com/llvm/torch-mlir/issues/3442>
3. Fixes the padding order for asymmetrically padded onnx.Conv ops (the
reordering is sketched after this description)
4. Enables quantization to propagate through those onnx.Conv op pre-paddings
5. Modifies the torch-to-linalg lowering of the AtenReplicationPad2d op to
support inputs with rank != 4

Unfortunately, even with all of these changes, the e2e tests for
ReplicationPad2d still fail with the onnx config, since the torch export
procedure for rearranging the pad order is complicated enough that the
padding ints cannot be folded back into constants.
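
For context, item 3 hinges on how pad lists are ordered: onnx.Pad / onnx.Conv
store pads as [x1_begin, ..., xn_begin, x1_end, ..., xn_end], while
torch.aten.pad expects (begin, end) pairs starting from the trailing
dimension. A minimal standalone sketch of that reordering (the helper name is
illustrative, not code from this patch):

```cpp
#include <cstdint>
#include <vector>

// Reorder ONNX-style pads [x1_begin, ..., xn_begin, x1_end, ..., xn_end]
// into the torch.aten.pad layout [xn_begin, xn_end, x(n-1)_begin, ...].
std::vector<int64_t>
rearrangeOnnxPadsForTorch(const std::vector<int64_t> &onnxPads) {
  const size_t rank = onnxPads.size() / 2;
  std::vector<int64_t> torchPads;
  torchPads.reserve(onnxPads.size());
  for (size_t i = 0; i < rank; ++i) {
    size_t dim = rank - 1 - i;                  // walk dimensions back to front
    torchPads.push_back(onnxPads[dim]);         // begin pad for this dimension
    torchPads.push_back(onnxPads[dim + rank]);  // end pad for this dimension
  }
  return torchPads;
}

// e.g. the asymmetric Conv test below uses ONNX pads [2, 0, 0, 2] on the
// spatial dims, which become the torch list [0, 2, 2, 0].
```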
branch: pull/3541/head
zjgarvey, 2024-07-11 18:01:45 -07:00, committed by GitHub
commit 0fb8b017d8 (parent b38585e077)
8 changed files with 443 additions and 122 deletions


@@ -1343,11 +1343,12 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
SmallVector<Value> padsRearrange;
SmallVector<Value> inputPaddingList;
for (uint32_t i = 0; i < padding.size() / 2; i++) {
padsRearrange.emplace_back(rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getI64IntegerAttr(padding[i])));
padsRearrange.emplace_back(rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getI64IntegerAttr(
padding[(padding.size() / 2) + i])));
padding[padding.size() / 2 - i - 1])));
padsRearrange.emplace_back(rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(),
rewriter.getI64IntegerAttr(padding[padding.size() - i - 1])));
inputPaddingList.emplace_back(
rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getI64IntegerAttr(0)));


@@ -2233,41 +2233,84 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
binder.op, "The axes parameter is not supported yet");
}
if (binder.tensorOperandAtIndex(data, 0) ||
binder.tensorOperandAtIndex(pads, 1) ||
binder.tensorResultType(resultType) ||
binder.customOpNameStringAttr(mode, "mode", "constant"))
return failure();
bool cstMode = (mode == "constant");
// get input rank
auto dataOpTy = cast<Torch::ValueTensorType>(data.getType());
TensorType dataTensor = dataOpTy.toBuiltinTensor();
if (!dataTensor || !dataTensor.hasRank())
return rewriter.notifyMatchFailure(
binder.op, "pad length unknown and data operand unranked");
int64_t dataRank = dataTensor.getRank();
int64_t padsSize = 2 * dataRank;
Location loc = binder.getLoc();
// Get pads shape and rank. The pads tensor is expected to be 1-D
// tensor.
auto padsTensorType = cast<Torch::ValueTensorType>(pads.getType());
if (!padsTensorType || !padsTensorType.hasSizes()) {
return rewriter.notifyMatchFailure(binder.op,
"Expect non empty pad tensor");
}
ArrayRef<int64_t> padsShape = padsTensorType.getSizes();
int64_t padsRank = padsShape.size();
if (padsRank != 1)
return rewriter.notifyMatchFailure(binder.op,
"expect 1-d pad tensor");
// get pads (earlier versions use an attribute, newer versions use a
// tensor input)
SmallVector<Value> padsTensorValue;
if (binder.tensorOperandAtIndex(pads, 1)) {
SmallVector<int64_t> defaultPads(2 * dataRank, 0);
SmallVector<int64_t> padInts;
if (binder.s64IntegerArrayAttr(padInts, "pads", defaultPads))
return rewriter.notifyMatchFailure(binder.op,
"pads binder failure");
// opset_version 1 uses the attribute name "paddings"
if (padInts == defaultPads) {
SmallVector<int64_t> paddingsInts;
if (binder.s64IntegerArrayAttr(paddingsInts, "paddings",
defaultPads))
return rewriter.notifyMatchFailure(binder.op,
"paddings binder failure");
padInts = paddingsInts;
}
for (auto p : padInts)
padsTensorValue.push_back(rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(p)));
} else {
// Get pads shape and rank. The pads tensor is expected to be 1-D
// tensor.
auto padsTensorType = cast<Torch::ValueTensorType>(pads.getType());
if (!padsTensorType || !padsTensorType.hasSizes()) {
return rewriter.notifyMatchFailure(binder.op,
"Expect non empty pad tensor");
}
ArrayRef<int64_t> padsShape = padsTensorType.getSizes();
int64_t padsRank = padsShape.size();
if (padsRank != 1)
return rewriter.notifyMatchFailure(binder.op,
"expect 1-d pad tensor");
if (padsShape[0] != Torch::kUnknownSize) {
// As per onnx.Pad documentation, padSize = 2*num_data_axes
// (if axes param not passed). Need to be updated when adding
// support for `axes` param.
padsSize = padsShape[0];
}
int64_t padsSize = padsShape[0];
if (padsSize == Torch::kUnknownSize) {
// As per onnx.Pad documentation, padSize = 2*num_data_axes
// (if axes param not passed). Need to be updated when adding
// support for `axes` param.
auto dataOpTy = cast<Torch::ValueTensorType>(data.getType());
TensorType dataTensor = dataOpTy.toBuiltinTensor();
if (!dataTensor || !dataTensor.hasRank())
return rewriter.notifyMatchFailure(
binder.op, "pad length unknown and data operand unranked");
int64_t dataRank = dataTensor.getRank();
padsSize = 2 * dataRank;
// Extract all the values of 1-D pad tensor and create a list of all
// these values as torch.pad op expects pad list.
Value constZero = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(0));
SmallVector<int64_t> emptyShape;
Type padsElemType = Torch::ValueTensorType::get(
padsTensorType.getContext(), emptyShape,
padsTensorType.getOptionalDtype());
for (uint32_t i = 0; i < padsSize; ++i) {
Value index = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(i));
auto select = rewriter.create<Torch::AtenSelectIntOp>(
loc, padsElemType, pads, constZero, index);
Value selectInt = rewriter.create<Torch::AtenItemOp>(
loc, rewriter.getType<Torch::IntType>(), select);
padsTensorValue.push_back(selectInt);
}
}
Value constantValue;
if (binder.getNumOperands() >= 3) {
if (binder.getNumOperands() >= 3 && cstMode) {
if (!binder.tensorOperandAtIndex(constantValue, 2)) {
auto constTy =
dyn_cast<Torch::BaseTensorType>(constantValue.getType());
@@ -2283,38 +2326,28 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
}
}
if (!constantValue) {
if (!constantValue && cstMode) {
auto dataTensorType = cast<Torch::ValueTensorType>(data.getType());
if (isa<IntegerType>(dataTensorType.getDtype()))
constantValue = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(0));
if (isa<FloatType>(dataTensorType.getDtype()))
// Earlier versions used a FLOAT attribute to store the constant
// value. The following will pick up on any non-default value attr if
// provided.
float constantFloat;
if (isa<FloatType>(dataTensorType.getDtype()) &&
!binder.f32FloatAttr(constantFloat, "value", 0.0f))
constantValue = rewriter.create<Torch::ConstantFloatOp>(
loc, rewriter.getF64FloatAttr(0.0f));
loc, rewriter.getF64FloatAttr(constantFloat));
if (!constantValue)
return rewriter.notifyMatchFailure(
binder.op, "expected integer or float data tensor");
}
// Extract all the values of 1-D pad tensor and create a list of all
// these values as torch.pad op expects pad list.
Value constZero = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(0));
SmallVector<Value> padsTensorValue;
SmallVector<int64_t> emptyShape;
Type padsElemType =
Torch::ValueTensorType::get(padsTensorType.getContext(), emptyShape,
padsTensorType.getOptionalDtype());
for (uint32_t i = 0; i < padsSize; ++i) {
Value index = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(i));
auto select = rewriter.create<Torch::AtenSelectIntOp>(
loc, padsElemType, pads, constZero, index);
Value selectInt = rewriter.create<Torch::AtenItemOp>(
loc, rewriter.getType<Torch::IntType>(), select);
padsTensorValue.push_back(selectInt);
}
// for modes other than "constant" a value is not required
if (!cstMode)
constantValue = rewriter.create<Torch::ConstantNoneOp>(loc);
// The torch.pad op expects a different arrangement of padding pairs for
// each dimension as compared to the onnx.pad op. Rearrange the pad
@@ -2335,6 +2368,20 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
Torch::ListType::get(rewriter.getType<Torch::IntType>()),
padsRearrange)
.getResult();
// lowering to AtenConstantPadNdOp directly allows passing any torch
// scalar type for the value, whereas AtenPadOp takes an optional float
// type.
if (cstMode && !isa<Torch::NoneType>(constantValue.getType())) {
rewriter.replaceOpWithNewOp<Torch::AtenConstantPadNdOp>(
binder.op, resultType, data, padsSizeList, constantValue);
return success();
}
// translate a few mismatching mode names ONNX -> Torch
mode = (mode == "edge") ? "replicate" : mode;
mode = (mode == "wrap") ? "circular" : mode;
Value modeVal = rewriter.create<Torch::ConstantStrOp>(
loc, rewriter.getStringAttr(mode));


@@ -97,8 +97,12 @@ public:
Type newResultType = getTypeConverter()->convertType(op.getType());
Type elementType = cast<RankedTensorType>(newResultType).getElementType();
auto dstOriginalDtype =
cast<Torch::ValueTensorType>(op.getType()).getDtype();
Value castedValue =
convertScalarToDtype(rewriter, loc, adaptor.getValue(), elementType);
convertScalarToDtype(rewriter, loc, adaptor.getValue(), elementType,
std::nullopt, dstOriginalDtype);
Type padType = tensor::PadOp::inferResultType(
cast<RankedTensorType>(self.getType()), staticLow, staticHigh);
@@ -209,26 +213,38 @@ public:
Value one = getConstant(rewriter, loc, 1, indexType);
Value hDimSizeMinusOne = createSub(hDimSize, one);
Value vDimSizeMinusOne = createSub(vDimSize, one);
SmallVector<Value> allOneStrides(numDims, one);
SmallVector<Value> allOneStridesVal(numDims, one);
SmallVector<OpFoldResult> allOneStrides =
getAsOpFoldResult(allOneStridesVal);
SmallVector<Value> extractOffsetsLT(numDims, zero);
extractOffsetsLT[hDim] = zero;
extractOffsetsLT[vDim] = zero;
SmallVector<Value> extractShapeLR(numDims, one);
extractShapeLR[hDim] = one;
extractShapeLR[vDim] = vDimSize;
SmallVector<Value> extractOffsetsLTVal(numDims, zero);
extractOffsetsLTVal[hDim] = zero;
extractOffsetsLTVal[vDim] = zero;
SmallVector<OpFoldResult> extractOffsetsLT =
getAsOpFoldResult(extractOffsetsLTVal);
SmallVector<Value> extractShapeLRVal(numDims, one);
extractShapeLRVal[hDim] = one;
extractShapeLRVal[vDim] = vDimSize;
SmallVector<OpFoldResult> extractShapeLR =
getAsOpFoldResult(extractShapeLRVal);
SmallVector<Value> extractOffsetsRight(numDims, zero);
extractOffsetsRight[hDim] = hDimSizeMinusOne;
extractOffsetsRight[vDim] = zero;
SmallVector<Value> extractOffsetsRightVal(numDims, zero);
extractOffsetsRightVal[hDim] = hDimSizeMinusOne;
extractOffsetsRightVal[vDim] = zero;
SmallVector<OpFoldResult> extractOffsetsRight =
getAsOpFoldResult(extractOffsetsRightVal);
SmallVector<Value> extractOffsetsBottom(numDims, zero);
extractOffsetsBottom[hDim] = zero;
extractOffsetsBottom[vDim] = vDimSizeMinusOne;
SmallVector<Value> extractOffsetsBottomVal(numDims, zero);
extractOffsetsBottomVal[hDim] = zero;
extractOffsetsBottomVal[vDim] = vDimSizeMinusOne;
SmallVector<OpFoldResult> extractOffsetsBottom =
getAsOpFoldResult(extractOffsetsBottomVal);
SmallVector<Value> extractShapeTB(numDims, one);
extractShapeTB[hDim] = hDimSize;
extractShapeTB[vDim] = one;
SmallVector<Value> extractShapeTBVal(numDims, one);
extractShapeTBVal[hDim] = hDimSize;
extractShapeTBVal[vDim] = one;
SmallVector<OpFoldResult> extractShapeTB =
getAsOpFoldResult(extractShapeTBVal);
SmallVector<Value> tensorsLeft;
SmallVector<Value> tensorsRight;
@@ -240,24 +256,26 @@ public:
Value vCenterLeftSlice = rewriter.create<tensor::ExtractSliceOp>(
loc, input, extractOffsetsLT, extractShapeLR, allOneStrides);
Value vLeftSlice = vCenterLeftSlice;
SmallVector<Value> extractIndices(numDims, zero);
if (hasTopPadding) {
Value topLeftValue = rewriter.create<tensor::ExtractOp>(
loc, input, ValueRange{zero, zero, zero, zero});
Value topLeftValue =
rewriter.create<tensor::ExtractOp>(loc, input, extractIndices);
// pad vCenterLeftSlice on the top
SmallVector<int64_t> lowPadding(4, 0);
SmallVector<int64_t> highPadding(4, 0);
lowPadding[2] = padInts[2];
SmallVector<int64_t> lowPadding(numDims, 0);
SmallVector<int64_t> highPadding(numDims, 0);
lowPadding[vDim] = padInts[2];
vLeftSlice = torch_to_linalg::getPaddedTensor(
op, rewriter, vLeftSlice, lowPadding, highPadding, topLeftValue);
}
if (hasBottomPadding) {
Value bottomLeftValue = rewriter.create<tensor::ExtractOp>(
loc, input, ValueRange{zero, zero, vDimSizeMinusOne, zero});
extractIndices[vDim] = vDimSizeMinusOne;
Value bottomLeftValue =
rewriter.create<tensor::ExtractOp>(loc, input, extractIndices);
// pad vLeftSlice at the bottom
SmallVector<int64_t> lowPadding(4, 0);
SmallVector<int64_t> highPadding(4, 0);
highPadding[2] = padInts[3];
SmallVector<int64_t> lowPadding(numDims, 0);
SmallVector<int64_t> highPadding(numDims, 0);
highPadding[vDim] = padInts[3];
vLeftSlice = torch_to_linalg::getPaddedTensor(
op, rewriter, vLeftSlice, lowPadding, highPadding, bottomLeftValue);
}
@@ -265,7 +283,7 @@ public:
tensorsLeft.push_back(vLeftSlice);
}
Value leftPadTile =
rewriter.create<tensor::ConcatOp>(loc, 3, tensorsLeft);
rewriter.create<tensor::ConcatOp>(loc, hDim, tensorsLeft);
tensorsRes.push_back(leftPadTile);
}
if (hasTopPadding) {
@@ -283,33 +301,35 @@ public:
tensorsCenter.push_back(bottomHcenterSlice);
}
}
centerTile = rewriter.create<tensor::ConcatOp>(loc, 2, tensorsCenter);
centerTile = rewriter.create<tensor::ConcatOp>(loc, vDim, tensorsCenter);
tensorsRes.push_back(centerTile);
if (hasRightPadding) {
Value vCenterRightSlice = rewriter.create<tensor::ExtractSliceOp>(
loc, input, extractOffsetsRight, extractShapeLR, allOneStrides);
Value vRightSlice = vCenterRightSlice;
SmallVector<Value> extractIndices(numDims, zero);
extractIndices[hDim] = hDimSizeMinusOne;
if (hasTopPadding) {
Value topRightValue = rewriter.create<tensor::ExtractOp>(
loc, input, ValueRange{zero, zero, zero, hDimSizeMinusOne});
// pad vCenterRightSlice on the top
SmallVector<int64_t> lowPadding(4, 0);
SmallVector<int64_t> highPadding(4, 0);
lowPadding[2] = padInts[2];
SmallVector<int64_t> lowPadding(numDims, 0);
SmallVector<int64_t> highPadding(numDims, 0);
lowPadding[vDim] = padInts[2];
vRightSlice = torch_to_linalg::getPaddedTensor(
op, rewriter, vRightSlice, lowPadding, highPadding, topRightValue);
}
if (hasBottomPadding) {
Value bottomRightValue = rewriter.create<tensor::ExtractOp>(
loc, input,
ValueRange{zero, zero, vDimSizeMinusOne, hDimSizeMinusOne});
extractIndices[vDim] = vDimSizeMinusOne;
Value bottomRightValue =
rewriter.create<tensor::ExtractOp>(loc, input, extractIndices);
// Pad vCenterRightSlice or vRightTopPaddedSlice at the bottom.
SmallVector<int64_t> lowPadding(4, 0);
SmallVector<int64_t> highPadding(4, 0);
highPadding[2] = padInts[3];
SmallVector<int64_t> lowPadding(numDims, 0);
SmallVector<int64_t> highPadding(numDims, 0);
highPadding[vDim] = padInts[3];
vRightSlice = torch_to_linalg::getPaddedTensor(
op, rewriter, vRightSlice, lowPadding, highPadding,
bottomRightValue);
@@ -318,10 +338,10 @@ public:
tensorsRight.push_back(vRightSlice);
}
Value rightPadTile =
rewriter.create<tensor::ConcatOp>(loc, 3, tensorsRight);
rewriter.create<tensor::ConcatOp>(loc, hDim, tensorsRight);
tensorsRes.push_back(rightPadTile);
}
Value resTensor = rewriter.create<tensor::ConcatOp>(loc, 3, tensorsRes);
Value resTensor = rewriter.create<tensor::ConcatOp>(loc, hDim, tensorsRes);
Type newResultType = getTypeConverter()->convertType(op.getType());
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, resTensor);
return success();
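
For reference, the semantics this lowering realizes, independent of the
tile-and-concat construction above: replication padding clamps every output
coordinate to the nearest in-bounds input coordinate along the two trailing
dims. A minimal standalone sketch of that behavior on a plain 2-D array (not
the lowering itself):

```cpp
#include <algorithm>
#include <array>
#include <cstdint>
#include <vector>

// Replication-pad a 2-D "image". pads = {left, right, top, bottom}, matching
// the AtenReplicationPad2d operand order; each output element copies the
// nearest in-bounds input element (clamp-to-edge).
std::vector<std::vector<float>>
replicationPad2d(const std::vector<std::vector<float>> &in,
                 const std::array<int64_t, 4> &pads) {
  int64_t h = in.size(), w = in[0].size();
  int64_t left = pads[0], right = pads[1], top = pads[2], bottom = pads[3];
  std::vector<std::vector<float>> out(
      h + top + bottom, std::vector<float>(w + left + right));
  for (int64_t i = 0; i < h + top + bottom; ++i)
    for (int64_t j = 0; j < w + left + right; ++j) {
      int64_t srcI = std::clamp<int64_t>(i - top, 0, h - 1);
      int64_t srcJ = std::clamp<int64_t>(j - left, 0, w - 1);
      out[i][j] = in[srcI][srcJ];
    }
  return out;
}
```

The change in this file keeps the same tile-and-concat structure but indexes
the height/width dims as hDim/vDim and sizes the padding vectors with numDims,
so inputs whose rank differs from 4 lower correctly.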


@@ -6379,17 +6379,91 @@ class DecomposeAtenPadOp : public OpRewritePattern<AtenPadOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(AtenPadOp op,
PatternRewriter &rewriter) const override {
std::string mode;
if (!matchPattern(op.getMode(), m_TorchConstantStr(mode)))
return rewriter.notifyMatchFailure(op, "mode must be a constant string");
Value value = op.getValue();
if (isa<Torch::OptionalType>(value.getType()))
return rewriter.notifyMatchFailure(op, "optional type not supported");
if (isa<Torch::NoneType>(value.getType()))
value = rewriter.create<Torch::ConstantFloatOp>(
op.getLoc(), rewriter.getF64FloatAttr(0));
if (mode == "constant") {
Value value = op.getValue();
if (isa<Torch::OptionalType>(value.getType()))
return rewriter.notifyMatchFailure(op, "optional type not supported");
if (isa<Torch::NoneType>(value.getType()))
value = rewriter.create<Torch::ConstantFloatOp>(
op.getLoc(), rewriter.getF64FloatAttr(0));
rewriter.replaceOpWithNewOp<AtenConstantPadNdOp>(
op, op.getType(), op.getSelf(), op.getPad(), value);
return success();
rewriter.replaceOpWithNewOp<AtenConstantPadNdOp>(
op, op.getType(), op.getSelf(), op.getPad(), value);
return success();
}
SmallVector<Value> padValues;
if (!getListConstructElements(op.getPad(), padValues))
return failure();
SmallVector<int64_t> padInts;
Value usefulPads = op.getPad();
uint64_t usefulPadIndexEnd = padValues.size();
// try to reduce the number of padding dims if possible
if (matchPattern(op.getPad(), m_TorchListOfConstantInts(padInts))) {
if ((padInts.size() % 2) == 1)
return rewriter.notifyMatchFailure(op,
"expected an even number of pads");
for (uint64_t i = padInts.size() - 1; i > 0; i -= 2) {
if (padInts[i] != 0 || padInts[i - 1] != 0)
break;
usefulPadIndexEnd = i - 1;
}
if (usefulPadIndexEnd == 0) {
rewriter.replaceOp(op, op.getSelf());
return success();
}
}
// we don't have support for 1-D replicate pad, so pass it as 2d if
// possible.
// TODO: add support for AtenReplicatePad1dOp and remove this.
if (mode == "replicate" && usefulPadIndexEnd == 2 && padValues.size() >= 4)
usefulPadIndexEnd = 4;
// make a new list of padding ints if dimensionality reduction can be
// performed
if (usefulPadIndexEnd < padValues.size()) {
ArrayRef<Value> usefulPadValues(padValues.begin(),
padValues.begin() + usefulPadIndexEnd);
usefulPads = rewriter.create<PrimListConstructOp>(
op.getLoc(),
rewriter.getType<Torch::ListType>(rewriter.getType<Torch::IntType>()),
usefulPadValues);
}
uint64_t numPadDims = usefulPadIndexEnd / 2;
if (mode == "reflect") {
// only support for reflection pad 1d and 2d
if (numPadDims == 2) {
rewriter.replaceOpWithNewOp<AtenReflectionPad2dOp>(
op, op.getType(), op.getSelf(), usefulPads);
return success();
}
if (numPadDims == 1) {
rewriter.replaceOpWithNewOp<AtenReflectionPad1dOp>(
op, op.getType(), op.getSelf(), usefulPads);
return success();
}
return failure();
}
if (mode == "replicate") {
// only support for replication pad 2d
if (numPadDims != 2)
return failure();
rewriter.replaceOpWithNewOp<AtenReplicationPad2dOp>(
op, op.getType(), op.getSelf(), usefulPads);
return success();
}
return rewriter.notifyMatchFailure(op, "unsupported mode: " + mode);
}
};
} // namespace
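
The trailing-pad trimming in the decomposition above can be summarized on its
own: in a torch-style pad list the trailing (begin, end) pairs belong to the
outermost dimensions, so pairs that are both zero can be dropped before
choosing a rank-specific pad op. A minimal standalone sketch (the helper name
is illustrative, not from the patch):

```cpp
#include <cstdint>
#include <vector>

// Count how many leading entries of a torch-style pad list still matter once
// trailing (begin, end) pairs that are both zero have been dropped.
size_t usefulPadLength(const std::vector<int64_t> &pads) {
  size_t end = pads.size();
  for (size_t i = pads.size(); i >= 2; i -= 2) {
    if (pads[i - 1] != 0 || pads[i - 2] != 0)
      break;
    end = i - 2;
  }
  return end;
}

// e.g. [1, 2, 0, 0, 0, 0] -> 2: only the innermost dim is padded, so a
// "reflect" pad on a rank-3 input can lower to torch.aten.reflection_pad1d.
```

As in the code above, a 1-d replicate pad is bumped back up to two pad dims
because only AtenReplicationPad2dOp is currently supported.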


@@ -40,7 +40,8 @@ bool isQCommutingOp(mlir::Operation *op) {
// if adding a new commuting op here, be sure to add a
// RemoveUnused pattern for that op to clean up afterwards
return llvm::isa<AtenTransposeIntOp, AtenReshapeOp, AtenSliceTensorOp,
PrimsCollapseOp, AtenViewOp>(op);
PrimsCollapseOp, AtenViewOp, AtenPadOp, AtenConstantPadNdOp>(
op);
}
// The following conversion takes patterns of the form [op0 -> MPTQT -> dequant
@@ -65,7 +66,7 @@ public:
for (unsigned i : QuantInfo<SrcOp>::operandsToQuantize) {
Value operand = operands[i];
std::stack<mlir::Operation *> commutingOpStack;
Value dequantOpd, MPTQTOpd;
Value dequantOpd, MPTQTOpd, scale, zeroPoint;
for (unsigned k = 0; k < depth + 1; k++) {
auto currOp = operand.getDefiningOp();
// Case 0 : currOp is a nullptr (e.g., operand is a block argument)
@@ -84,6 +85,8 @@ public:
auto MPTQTOp =
dequantOpd.getDefiningOp<Aten_MakePerTensorQuantizedTensorOp>();
MPTQTOpd = MPTQTOp.getOperand(0);
scale = MPTQTOp.getOperand(1);
zeroPoint = MPTQTOp.getOperand(2);
}
// either a dequant was found or chain broken, so break loop
break;
@@ -107,6 +110,47 @@ public:
commutingOpStack.pop();
llvm::SmallVector<Value> currOperands(currOp->getOperands());
currOperands[0] = oldOpd;
// pad ops aren't quite commuting, so we include some extra logic to
// quantize the padding value
if (isa<Torch::AtenPadOp, Torch::AtenConstantPadNdOp>(currOp)) {
Value floatPadValue = currOperands.back();
Value quantPadValue;
if (isa<Torch::NoneType>(floatPadValue.getType()))
quantPadValue = rewriter.create<AtenFloatScalarOp>(loc, zeroPoint);
else {
floatPadValue =
rewriter.create<AtenFloatScalarOp>(loc, floatPadValue);
quantPadValue = rewriter.create<Torch::AtenDivFloatOp>(
loc, floatPadValue, scale);
quantPadValue = rewriter.create<Torch::AtenAddFloatIntOp>(
loc, quantPadValue, zeroPoint);
}
// clamp pad value to qint range
if (auto intType = dyn_cast<mlir::IntegerType>(intDType)) {
bool isSigned = intType.isSignedInteger();
int64_t width = intType.getWidth();
assert(width < 64 &&
"quantized int bitwidth should be less than 64");
int64_t minInt = isSigned ? -(1 << (width - 1)) : 0;
int64_t maxInt = isSigned ? -minInt - 1 : ((1 << width) - 1);
Value minQValueFloat = rewriter.create<ConstantFloatOp>(
loc, rewriter.getF64FloatAttr(minInt));
Value maxQValueFloat = rewriter.create<ConstantFloatOp>(
loc, rewriter.getF64FloatAttr(maxInt));
SmallVector<int64_t> emptyShape;
auto floatTensorType = rewriter.getType<Torch::ValueTensorType>(
emptyShape, rewriter.getF64Type());
Value quantPadValueTensor = createRank0Tensor(
rewriter, loc, floatTensorType, quantPadValue);
Value clampedTensor = rewriter.create<Torch::AtenClampOp>(
loc, floatTensorType, quantPadValueTensor, minQValueFloat,
maxQValueFloat);
quantPadValue = rewriter.create<Torch::AtenItemOp>(
loc, rewriter.getType<Torch::FloatType>(), clampedTensor);
}
// quantPadValue is a float, but will get converted/truncated
currOperands.back() = quantPadValue;
}
// get new result type
auto oldType = cast<ValueTensorType>(currOp->getResultTypes()[0]);
auto intType =
@@ -374,7 +418,8 @@ public:
RemoveUnused<Aten_MakePerTensorQuantizedTensorOp>,
RemoveUnused<AtenTransposeIntOp>, RemoveUnused<AtenSliceTensorOp>,
RemoveUnused<AtenReshapeOp>, RemoveUnused<PrimsCollapseOp>,
RemoveUnused<AtenViewOp>,
RemoveUnused<AtenViewOp>, RemoveUnused<AtenPadOp>,
RemoveUnused<AtenConstantPadNdOp>,
QuantizeOperandsPastCommutingOps<AtenConvolutionOp, 5>,
QuantizeOperandsPastCommutingOps<AtenReluOp, 0>,
QuantizeOperandsPastCommutingOps<AtenMatmulOp, 2>,

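The pad-value handling added to this pass boils down to re-expressing the
float pad value in the quantized domain (value / scale + zero_point) and
clamping it to the representable range of the quantized integer type. A
standalone sketch of that arithmetic (plain C++, assuming width < 64 as the
pass asserts):

```cpp
#include <algorithm>
#include <cstdint>

// Pad value in the quantized domain: value / scale + zero_point, clamped to
// the range of a `width`-bit (un)signed integer. The caller truncates to int.
double quantizePadValue(double padValue, double scale, int64_t zeroPoint,
                        unsigned width, bool isSigned) {
  double q = padValue / scale + static_cast<double>(zeroPoint);
  int64_t minInt = isSigned ? -(int64_t(1) << (width - 1)) : 0;
  int64_t maxInt = isSigned ? (int64_t(1) << (width - 1)) - 1
                            : (int64_t(1) << width) - 1;
  return std::clamp(q, static_cast<double>(minInt),
                    static_cast<double>(maxInt));
}

// e.g. the mm_pad_commute test below: pad 3.5, scale 0.5, zero point 1, si8
// gives 3.5 / 0.5 + 1 = 8.0, which is already inside [-128, 127].
```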

@@ -1014,6 +1014,38 @@ func.func @test_conv_with_strides_padding(%arg0: !torch.vtensor<[1,1,7,5],f32>,
// -----
// CHECK-LABEL: @test_conv_with_asymmetric_padding
func.func @test_conv_with_asymmetric_padding(%arg0: !torch.vtensor<[1,1,7,5],f32>, %arg1: !torch.vtensor<[1,1,3,3],f32>) -> !torch.vtensor<[1,1,4,3],f32> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: %[[int0:.*]] = torch.constant.int 0
// CHECK: %[[int2:.*]] = torch.constant.int 2
// CHECK: %[[int0_0:.*]] = torch.constant.int 0
// CHECK: %[[int2_1:.*]] = torch.constant.int 2
// CHECK: %[[int0_2:.*]] = torch.constant.int 0
// CHECK: %[[int0_3:.*]] = torch.constant.int 0
// CHECK: %[[FakePADS:.*]] = torch.prim.ListConstruct %[[int0_0]], %[[int0_3]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[OGPADS:.*]] = torch.prim.ListConstruct %[[int0]], %[[int2]], %[[int2_1]], %[[int0_2]] : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[str:.*]] = torch.constant.str "constant"
// CHECK: %[[float0:.*]] = torch.constant.float 0.000
// CHECK: %[[PrePad:.*]] = torch.aten.pad %arg0, %[[OGPADS]], %[[str]], %[[float0]] : !torch.vtensor<[1,1,7,5],f32>, !torch.list<int>, !torch.str, !torch.float -> !torch.vtensor<[1,1,9,7],f32>
// CHECK: %[[C1_1:.*]] = torch.constant.int 1
// CHECK: %[[C1_2:.*]] = torch.constant.int 1
// CHECK: %[[C2:.*]] = torch.constant.int 2
// CHECK: %[[C2_0:.*]] = torch.constant.int 2
// CHECK: %[[C0:.*]] = torch.constant.int 0
// CHECK: %[[DILATIONS:.*]] = torch.prim.ListConstruct %[[C1_1]], %[[C1_2]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[STRIDE:.*]] = torch.prim.ListConstruct %[[C2]], %[[C2_0]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[OUTPUT_PADDING:.*]] = torch.prim.ListConstruct %[[C0]], %[[C0]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[TRANSPOSED:.*]] = torch.constant.bool false
// CHECK: %[[BIAS:.*]] = torch.constant.none
// CHECK: %[[GROUPS:.*]] = torch.constant.int 1
// CHECK: %[[Conv:.*]] = torch.aten.convolution %[[PrePad]], %arg1, %[[BIAS]], %[[STRIDE]], %[[FakePADS]], %[[DILATIONS]], %[[TRANSPOSED]], %[[OUTPUT_PADDING]], %[[GROUPS]] : !torch.vtensor<[1,1,9,7],f32>, !torch.vtensor<[1,1,3,3],f32>, !torch.none, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int -> !torch.vtensor<[1,1,4,3],f32>
// CHECK: return %[[Conv]]
%0 = torch.operator "onnx.Conv"(%arg0, %arg1) {torch.onnx.kernel_shape = [3 : si64, 3 : si64], torch.onnx.pads = [2 : si64, 0 : si64, 0 : si64, 2 : si64], torch.onnx.strides = [2 : si64, 2 : si64]} : (!torch.vtensor<[1,1,7,5],f32>, !torch.vtensor<[1,1,3,3],f32>) -> !torch.vtensor<[1,1,4,3],f32>
return %0 : !torch.vtensor<[1,1,4,3],f32>
}
// -----
// CHECK-LABEL: @test_conv_with_bias_strides_padding
func.func @test_conv_with_bias_strides_padding(%arg0: !torch.vtensor<[?,?,224,224],f32>, %arg1: !torch.vtensor<[64,3,7,7],f32>, %arg2: !torch.vtensor<[64],f32>) -> !torch.vtensor<[?,64,112,112],f32> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: %[[C3:.*]] = torch.constant.int 3


@@ -815,32 +815,40 @@ func.func @test_grid_sampler02(%arg0: !torch.vtensor<[5,10,10,4],f32>, %arg1: !t
// -----
// CHECK-LABEL: @test_grid_sampler03
// CHECK: %[[INT1:.*]] = torch.constant.int 1
// CHECK: %[[INT0:.*]] = torch.constant.int 0
// CHECK: %[[B0:.*]] = torch.constant.bool true
// CHECK: %[[A0:.*]] = torch.aten.grid_sampler %arg0, %arg1, %[[INT1]], %[[INT0]], %[[B0]] : !torch.vtensor<[5,10,10,4],f32>
func.func @test_grid_sampler03(%arg0: !torch.vtensor<[5,10,10,4],f32>, %arg1: !torch.vtensor<[5,7,8,2],f32>) -> !torch.vtensor<[?,?,?,?],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 17 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
%0 = torch.operator "onnx.GridSample" (%arg0, %arg1) {torch.onnx.align_corners = 1 : si64, torch.onnx.mode = "nearest", torch.onnx.padding_mode = "zeros"} : (!torch.vtensor<[5,10,10,4],f32>, !torch.vtensor<[5,7,8,2],f32>) -> !torch.vtensor<[?,?,?,?],f32>
return %0 : !torch.vtensor<[?,?,?,?],f32>
}
// -----
// CHECK-LABEL: func.func @test_oldest_pad
func.func @test_oldest_pad(%arg0: !torch.vtensor<[3,4],f32>) -> !torch.vtensor<[5,4],f32> attributes {torch.onnx_meta.opset_version = 1 : si64} {
// CHECK: %[[int0:.*]] = torch.constant.int 0
// CHECK: %[[int0_0:.*]] = torch.constant.int 0
// CHECK: %[[int2:.*]] = torch.constant.int 2
// CHECK: %[[int0_1:.*]] = torch.constant.int 0
// CHECK: %[[float0:.*]] = torch.constant.float 0.000000e+00
// CHECK: %[[list:.*]] = torch.prim.ListConstruct %[[int0_0]], %[[int0_1]], %[[int0]], %[[int2]] : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[pad:.*]] = torch.aten.constant_pad_nd %arg0, %[[list]], %[[float0]] : !torch.vtensor<[3,4],f32>, !torch.list<int>, !torch.float -> !torch.vtensor<[5,4],f32>
// CHECK: return %[[pad]] : !torch.vtensor<[5,4],f32>
%0 = torch.operator "onnx.Pad"(%arg0) {torch.onnx.mode = "constant", torch.onnx.paddings = [0 : si64, 0 : si64, 2 : si64, 0 : si64], torch.onnx.value = 0.000000e+00 : f32} : (!torch.vtensor<[3,4],f32>) -> !torch.vtensor<[5,4],f32>
return %0 : !torch.vtensor<[5,4],f32>
}
// -----
// CHECK-LABEL: func.func @test_less_or_equal
func.func @test_less_or_equal(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],i1> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 16 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: !torch.vtensor<[3,4,5],f32>
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: !torch.vtensor<[3,4,5],f32>
// CHECK: torch.aten.le.Tensor %[[ARG0]], %[[ARG1]] : !torch.vtensor<[3,4,5],f32>, !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],i1>
%0 = torch.operator "onnx.LessOrEqual"(%arg0, %arg1) : (!torch.vtensor<[3,4,5],f32>, !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],i1>
return %0 : !torch.vtensor<[3,4,5],i1>
}
// -----
// CHECK-LABEL: func.func @test_old_pad
func.func @test_old_pad(%arg0: !torch.vtensor<[3,4],f32>) -> !torch.vtensor<[5,4],f32> attributes {torch.onnx_meta.opset_version = 11 : si64} {
// CHECK: %[[int0:.*]] = torch.constant.int 0
// CHECK: %[[int0_0:.*]] = torch.constant.int 0
// CHECK: %[[int2:.*]] = torch.constant.int 2
// CHECK: %[[int0_1:.*]] = torch.constant.int 0
// CHECK: %[[float0:.*]] = torch.constant.float 0.000000e+00
// CHECK: %[[list:.*]] = torch.prim.ListConstruct %[[int0_0]], %[[int0_1]], %[[int0]], %[[int2]] : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[pad:.*]] = torch.aten.constant_pad_nd %arg0, %[[list]], %[[float0]] : !torch.vtensor<[3,4],f32>, !torch.list<int>, !torch.float -> !torch.vtensor<[5,4],f32>
// CHECK: return %[[pad]] : !torch.vtensor<[5,4],f32>
%0 = torch.operator "onnx.Pad"(%arg0) {torch.onnx.mode = "constant", torch.onnx.pads = [0 : si64, 0 : si64, 2 : si64, 0 : si64], torch.onnx.value = 0.000000e+00 : f32} : (!torch.vtensor<[3,4],f32>) -> !torch.vtensor<[5,4],f32>
return %0 : !torch.vtensor<[5,4],f32>
}
// -----
// CHECK-LABEL: func.func @test_pad
func.func @test_pad(%arg0: !torch.vtensor<[3,4],f32>, %arg1: !torch.vtensor<[4], si64>, %arg2: !torch.vtensor<[], f32>) -> !torch.vtensor<[5,4],f32> attributes {torch.onnx_meta.opset_version = 19 : si64} {
// CHECK: %[[VAL:.+]] = torch.aten.item %arg2 : !torch.vtensor<[],f32> -> !torch.float
// CHECK: %[[INT0:.+]] = torch.constant.int 0
// CHECK: %[[INT0_0:.+]] = torch.constant.int 0
// CHECK: %[[SELECT_0:.+]] = torch.aten.select.int %arg1, %[[INT0]], %[[INT0_0]] : !torch.vtensor<[4],si64>, !torch.int, !torch.int -> !torch.vtensor<[],si64>
@@ -854,9 +862,9 @@ func.func @test_pad(%arg0: !torch.vtensor<[3,4],f32>, %arg1: !torch.vtensor<[4],
// CHECK: %[[INT3:.+]] = torch.constant.int 3
// CHECK: %[[SELECT_3:.+]] = torch.aten.select.int %arg1, %[[INT0]], %[[INT3]] : !torch.vtensor<[4],si64>, !torch.int, !torch.int -> !torch.vtensor<[],si64>
// CHECK: %[[ITEM_3:.+]] = torch.aten.item %[[SELECT_3]] : !torch.vtensor<[],si64> -> !torch.int
// CHECK: %[[VAL:.+]] = torch.aten.item %arg2 : !torch.vtensor<[],f32> -> !torch.float
// CHECK: %[[LIST:.+]] = torch.prim.ListConstruct %[[ITEM_1]], %[[ITEM_3]], %[[ITEM_0]], %[[ITEM_2]] : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[STR:.+]] = torch.constant.str "constant"
// CHECK: %[[PAD:.+]] = torch.aten.pad %arg0, %[[LIST]], %[[STR]], %[[VAL]] : !torch.vtensor<[3,4],f32>, !torch.list<int>, !torch.str, !torch.float -> !torch.vtensor<[5,4],f32>
// CHECK: %[[PAD:.+]] = torch.aten.constant_pad_nd %arg0, %[[LIST]], %[[VAL]] : !torch.vtensor<[3,4],f32>, !torch.list<int>, !torch.float -> !torch.vtensor<[5,4],f32>
// CHECK: return %[[PAD]] : !torch.vtensor<[5,4],f32>
%0 = torch.operator "onnx.Pad"(%arg0, %arg1, %arg2) {torch.onnx.mode = "constant"} : (!torch.vtensor<[3,4],f32>, !torch.vtensor<[4], si64>, !torch.vtensor<[], f32>) -> !torch.vtensor<[5,4],f32>
return %0 : !torch.vtensor<[5,4],f32>
@@ -864,12 +872,36 @@ func.func @test_pad(%arg0: !torch.vtensor<[3,4],f32>, %arg1: !torch.vtensor<[4],
// -----
// CHECK-LABEL: func.func @test_i32pad
func.func @test_i32pad(%arg0: !torch.vtensor<[3,4],si32>, %arg1: !torch.vtensor<[4], si64>, %arg2: !torch.vtensor<[], si32>) -> !torch.vtensor<[5,4],si32> attributes {torch.onnx_meta.opset_version = 19 : si64} {
// CHECK: %[[INT0:.+]] = torch.constant.int 0
// CHECK: %[[INT0_0:.+]] = torch.constant.int 0
// CHECK: %[[SELECT_0:.+]] = torch.aten.select.int %arg1, %[[INT0]], %[[INT0_0]] : !torch.vtensor<[4],si64>, !torch.int, !torch.int -> !torch.vtensor<[],si64>
// CHECK: %[[ITEM_0:.+]] = torch.aten.item %[[SELECT_0]] : !torch.vtensor<[],si64> -> !torch.int
// CHECK: %[[INT1:.+]] = torch.constant.int 1
// CHECK: %[[SELECT_1:.+]] = torch.aten.select.int %arg1, %[[INT0]], %[[INT1]] : !torch.vtensor<[4],si64>, !torch.int, !torch.int -> !torch.vtensor<[],si64>
// CHECK: %[[ITEM_1:.+]] = torch.aten.item %[[SELECT_1]] : !torch.vtensor<[],si64> -> !torch.int
// CHECK: %[[INT2:.+]] = torch.constant.int 2
// CHECK: %[[SELECT_2:.+]] = torch.aten.select.int %arg1, %[[INT0]], %[[INT2]] : !torch.vtensor<[4],si64>, !torch.int, !torch.int -> !torch.vtensor<[],si64>
// CHECK: %[[ITEM_2:.+]] = torch.aten.item %[[SELECT_2]] : !torch.vtensor<[],si64> -> !torch.int
// CHECK: %[[INT3:.+]] = torch.constant.int 3
// CHECK: %[[SELECT_3:.+]] = torch.aten.select.int %arg1, %[[INT0]], %[[INT3]] : !torch.vtensor<[4],si64>, !torch.int, !torch.int -> !torch.vtensor<[],si64>
// CHECK: %[[ITEM_3:.+]] = torch.aten.item %[[SELECT_3]] : !torch.vtensor<[],si64> -> !torch.int
// CHECK: %[[VAL:.+]] = torch.aten.item %arg2 : !torch.vtensor<[],si32> -> !torch.int
// CHECK: %[[LIST:.+]] = torch.prim.ListConstruct %[[ITEM_1]], %[[ITEM_3]], %[[ITEM_0]], %[[ITEM_2]] : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[PAD:.+]] = torch.aten.constant_pad_nd %arg0, %[[LIST]], %[[VAL]] : !torch.vtensor<[3,4],si32>, !torch.list<int>, !torch.int -> !torch.vtensor<[5,4],si32>
// CHECK: return %[[PAD]] : !torch.vtensor<[5,4],si32>
%0 = torch.operator "onnx.Pad"(%arg0, %arg1, %arg2) {torch.onnx.mode = "constant"} : (!torch.vtensor<[3,4],si32>, !torch.vtensor<[4], si64>, !torch.vtensor<[], si32>) -> !torch.vtensor<[5,4],si32>
return %0 : !torch.vtensor<[5,4],si32>
}
// -----
// CHECK-LABEL: @test_pad_optional_constant
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[3,4],f32>
// CHECK-SAME: %[[ARG1:.*]]: !torch.vtensor<[4],si64>
// CHECK: %[[VAL:.+]] = torch.constant.float 0
// CHECK: %[[CONST_STR:.*]] = torch.constant.str "constant"
// CHECK: torch.aten.pad %[[ARG0]], %{{.*}}, %[[CONST_STR]], %[[VAL]] : !torch.vtensor<[3,4],f32>, !torch.list<int>, !torch.str, !torch.float -> !torch.vtensor<[5,4],f32>
// CHECK: torch.aten.constant_pad_nd %[[ARG0]], %{{.*}}, %[[VAL]] : !torch.vtensor<[3,4],f32>, !torch.list<int>, !torch.float -> !torch.vtensor<[5,4],f32>
func.func @test_pad_optional_constant(%arg0: !torch.vtensor<[3,4],f32>, %arg1: !torch.vtensor<[4], si64>) -> !torch.vtensor<[5,4],f32> attributes {torch.onnx_meta.opset_version = 19 : si64} {
%0 = torch.operator "onnx.Pad"(%arg0, %arg1) {torch.onnx.mode = "constant"} : (!torch.vtensor<[3,4],f32>, !torch.vtensor<[4], si64>) -> !torch.vtensor<[5,4],f32>
@@ -878,6 +910,34 @@ func.func @test_pad_optional_constant(%arg0: !torch.vtensor<[3,4],f32>, %arg1: !
// -----
// CHECK-LABEL: @test_pad_wrap
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[3,4],f32>
// CHECK-SAME: %[[ARG1:.*]]: !torch.vtensor<[4],si64>
// CHECK: %[[VAL:.+]] = torch.constant.none
// CHECK: %[[STR:.+]] = torch.constant.str "circular"
// CHECK: torch.aten.pad %[[ARG0]], %{{.*}}, %[[STR]], %[[VAL]] : !torch.vtensor<[3,4],f32>, !torch.list<int>, !torch.str, !torch.none -> !torch.vtensor<[5,4],f32>
func.func @test_pad_wrap(%arg0: !torch.vtensor<[3,4],f32>, %arg1: !torch.vtensor<[4], si64>) -> !torch.vtensor<[5,4],f32> attributes {torch.onnx_meta.opset_version = 19 : si64} {
%0 = torch.operator "onnx.Pad"(%arg0, %arg1) {torch.onnx.mode = "wrap"} : (!torch.vtensor<[3,4],f32>, !torch.vtensor<[4], si64>) -> !torch.vtensor<[5,4],f32>
return %0 : !torch.vtensor<[5,4],f32>
}
// -----
// CHECK-LABEL: @test_pad_edge
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[3,4],f32>
// CHECK-SAME: %[[ARG1:.*]]: !torch.vtensor<[4],si64>
// CHECK: %[[VAL:.+]] = torch.constant.none
// CHECK: %[[STR:.+]] = torch.constant.str "replicate"
// CHECK: torch.aten.pad %[[ARG0]], %{{.*}}, %[[STR]], %[[VAL]] : !torch.vtensor<[3,4],f32>, !torch.list<int>, !torch.str, !torch.none -> !torch.vtensor<[5,4],f32>
func.func @test_pad_edge(%arg0: !torch.vtensor<[3,4],f32>, %arg1: !torch.vtensor<[4], si64>) -> !torch.vtensor<[5,4],f32> attributes {torch.onnx_meta.opset_version = 19 : si64} {
%0 = torch.operator "onnx.Pad"(%arg0, %arg1) {torch.onnx.mode = "edge"} : (!torch.vtensor<[3,4],f32>, !torch.vtensor<[4], si64>) -> !torch.vtensor<[5,4],f32>
return %0 : !torch.vtensor<[5,4],f32>
}
// -----
// CHECK-LABEL: func.func @test_pow
func.func @test_pow(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 15 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: torch.aten.pow.Tensor_Tensor %arg0, %arg1 : !torch.vtensor<[3,4,5],f32>, !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],f32>


@@ -82,6 +82,48 @@ func.func @matmul_commuting(%arg0: !torch.vtensor<[2,128,32,32],si8>) -> !torch.
// -----
// CHECK-LABEL: func.func @mm_pad_commute
func.func @mm_pad_commute(%arg0: !torch.vtensor<[8,8],si8>, %arg1: !torch.vtensor<[11,4],si8>) -> !torch.vtensor<[9,4],f32> {
// CHECK-DAG: %[[cstQuart:.*]] = torch.constant.float 2.500000e-01
// CHECK-DAG: %[[int7:.*]] = torch.constant.int 7
// CHECK-DAG: %[[none:.*]] = torch.constant.none
// CHECK-DAG: %[[qMax:.*]] = torch.constant.float 1.270000e+02
// CHECK-DAG: %[[qMin:.*]] = torch.constant.float -1.280000e+02
// CHECK-DAG: %[[padVal:.*]] = torch.constant.float 8.000000e+00
// CHECK-DAG: %[[str:.*]] = torch.constant.str "constant"
// CHECK-DAG: %[[cstHalf:.*]] = torch.constant.float 5.000000e-01
// CHECK-DAG: %[[int0:.*]] = torch.constant.int 0
// CHECK-DAG: %[[int1:.*]] = torch.constant.int 1
// CHECK-DAG: %[[int2:.*]] = torch.constant.int 2
// CHECK: %[[PadList:.*]] = torch.prim.ListConstruct %[[int1]], %[[int2]], %[[int0]], %[[int1]] : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[EmptyList:.*]] = torch.prim.ListConstruct : () -> !torch.list<int>
// CHECK: %[[Rank0:.*]] = torch.aten.full %[[EmptyList]], %[[padVal]], %[[int7]], %[[none]], %[[none]], %[[none]] : !torch.list<int>, !torch.float, !torch.int, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[],f64>
// CHECK: %[[Clamp:.*]] = torch.aten.clamp %[[Rank0]], %[[qMin]], %[[qMax]] : !torch.vtensor<[],f64>, !torch.float, !torch.float -> !torch.vtensor<[],f64>
// CHECK: %[[Item:.*]] = torch.aten.item %[[Clamp]] : !torch.vtensor<[],f64> -> !torch.float
// CHECK: %[[NewPad:.*]] = torch.aten.pad %arg0, %[[PadList]], %[[str]], %[[Item]] : !torch.vtensor<[8,8],si8>, !torch.list<int>, !torch.str, !torch.float -> !torch.vtensor<[9,11],si8>
// CHECK: %[[NewMPTQT:.*]] = torch.aten._make_per_tensor_quantized_tensor %[[NewPad]], %[[cstHalf]], %[[int1]] : !torch.vtensor<[9,11],si8>, !torch.float, !torch.int -> !torch.vtensor<[9,11],!torch.qint8>
// CHECK: %[[OtherMPTQT:.*]] = torch.aten._make_per_tensor_quantized_tensor %arg1, %[[cstHalf]], %[[int0]] : !torch.vtensor<[11,4],si8>, !torch.float, !torch.int -> !torch.vtensor<[11,4],!torch.qint8>
// CHECK: %[[MM:.*]] = torch.aten.mm %[[NewMPTQT]], %[[OtherMPTQT]] : !torch.vtensor<[9,11],!torch.qint8>, !torch.vtensor<[11,4],!torch.qint8> -> !torch.vtensor<[9,4],!torch.qint32>
%scale = torch.constant.float 0.5
%false = torch.constant.bool false
%zero = torch.constant.int 0
%one = torch.constant.int 1
%two = torch.constant.int 2
%floatpad = torch.constant.float 3.5
%zp = torch.constant.int -128
%6 = torch.aten._make_per_tensor_quantized_tensor %arg0, %scale, %one : !torch.vtensor<[8,8],si8>, !torch.float, !torch.int -> !torch.vtensor<[8,8],!torch.qint8>
%7 = torch.aten.dequantize.tensor %6 : !torch.vtensor<[8,8],!torch.qint8> -> !torch.vtensor<[8,8],f32>
%list = torch.prim.ListConstruct %one, %two, %zero, %one : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
%str = torch.constant.str "constant"
%pad = torch.aten.pad %7, %list, %str, %floatpad : !torch.vtensor<[8,8],f32>, !torch.list<int>, !torch.str, !torch.float -> !torch.vtensor<[9,11],f32>
%12 = torch.aten._make_per_tensor_quantized_tensor %arg1, %scale, %zero : !torch.vtensor<[11,4],si8>, !torch.float, !torch.int -> !torch.vtensor<[11,4],!torch.qint8>
%13 = torch.aten.dequantize.tensor %12 : !torch.vtensor<[11,4],!torch.qint8> -> !torch.vtensor<[11,4],f32>
%16 = torch.aten.mm %pad, %13 : !torch.vtensor<[9,11],f32>, !torch.vtensor<[11,4],f32> -> !torch.vtensor<[9,4],f32>
return %16 : !torch.vtensor<[9,4],f32>
}
// -----
// CHECK-LABEL: @convolution_bias
func.func @convolution_bias(%arg0: !torch.vtensor<[1,3,8,8],si8>, %arg1: !torch.vtensor<[3,3,2,2],si8>, %arg2 : !torch.vtensor<[3], f32>) -> !torch.vtensor<[1,3,7,7],f32> {
%scale = torch.constant.float 0.5