mirror of https://github.com/llvm/torch-mlir
Adds misc fixes for some padding related issues (#3528)
This patch adds a few miscellaneous pad-op-related changes:

1. Addresses issue <https://github.com/llvm/torch-mlir/issues/3457>
2. Addresses issue <https://github.com/llvm/torch-mlir/issues/3442>
3. Fixes the padding order for asymmetrically padded onnx.Conv ops
4. Enables passing quantization through those onnx.Conv op pre-paddings
5. Modifies the torch-to-linalg lowering of the AtenReplicationPad2d op to support inputs of rank other than 4

Unfortunately, even with all of these changes, the e2e tests for ReplicationPad2d still fail under the onnx config, since the torch export procedure for rearranging the pad order is complicated enough that the padding ints cannot be folded back into constants.
parent: b38585e077
commit: 0fb8b017d8
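For context on the onnx.Conv and onnx.Pad changes below: ONNX lists pad amounts as [x1_begin, x2_begin, ..., x1_end, x2_end, ...], while torch.aten.pad expects begin/end pairs starting from the last dimension, [xn_begin, xn_end, ..., x1_begin, x1_end]. A minimal standalone sketch of that rearrangement (an illustrative helper, not code from this patch):

#include <cstdint>
#include <vector>

// Rearranges ONNX-style pads [x1_begin, ..., xn_begin, x1_end, ..., xn_end]
// into torch-style pads [xn_begin, xn_end, ..., x1_begin, x1_end].
std::vector<int64_t> onnxPadsToTorchPads(const std::vector<int64_t> &pads) {
  size_t rank = pads.size() / 2;
  std::vector<int64_t> torchPads;
  torchPads.reserve(pads.size());
  for (size_t i = 0; i < rank; ++i) {
    size_t dim = rank - i - 1;             // walk dimensions last-to-first
    torchPads.push_back(pads[dim]);        // begin padding for this dim
    torchPads.push_back(pads[dim + rank]); // end padding for this dim
  }
  return torchPads;
}

// e.g. onnx pads [2, 0, 0, 2] (H_begin=2, W_begin=0, H_end=0, W_end=2) become
// [0, 2, 2, 0] (W_begin, W_end, H_begin, H_end), which is the order checked in
// the test_conv_with_asymmetric_padding test further down.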
@@ -1343,11 +1343,12 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
           SmallVector<Value> padsRearrange;
           SmallVector<Value> inputPaddingList;
           for (uint32_t i = 0; i < padding.size() / 2; i++) {
-            padsRearrange.emplace_back(rewriter.create<Torch::ConstantIntOp>(
-                binder.getLoc(), rewriter.getI64IntegerAttr(padding[i])));
             padsRearrange.emplace_back(rewriter.create<Torch::ConstantIntOp>(
                 binder.getLoc(), rewriter.getI64IntegerAttr(
-                                     padding[(padding.size() / 2) + i])));
+                                     padding[padding.size() / 2 - i - 1])));
+            padsRearrange.emplace_back(rewriter.create<Torch::ConstantIntOp>(
+                binder.getLoc(),
+                rewriter.getI64IntegerAttr(padding[padding.size() - i - 1])));
             inputPaddingList.emplace_back(
                 rewriter.create<Torch::ConstantIntOp>(
                     binder.getLoc(), rewriter.getI64IntegerAttr(0)));
@@ -2233,41 +2233,84 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
               binder.op, "The axes parameter is not supported yet");
         }
         if (binder.tensorOperandAtIndex(data, 0) ||
-            binder.tensorOperandAtIndex(pads, 1) ||
             binder.tensorResultType(resultType) ||
             binder.customOpNameStringAttr(mode, "mode", "constant"))
           return failure();
+        bool cstMode = (mode == "constant");
+
+        // get input rank
+        auto dataOpTy = cast<Torch::ValueTensorType>(data.getType());
+        TensorType dataTensor = dataOpTy.toBuiltinTensor();
+        if (!dataTensor || !dataTensor.hasRank())
+          return rewriter.notifyMatchFailure(
+              binder.op, "pad length unknown and data operand unranked");
+        int64_t dataRank = dataTensor.getRank();
+        int64_t padsSize = 2 * dataRank;
+
         Location loc = binder.getLoc();

-        // Get pads shape and rank. The pads tensor is expected to be 1-D
-        // tensor.
-        auto padsTensorType = cast<Torch::ValueTensorType>(pads.getType());
-        if (!padsTensorType || !padsTensorType.hasSizes()) {
-          return rewriter.notifyMatchFailure(binder.op,
-                                             "Expect non empty pad tensor");
-        }
-        ArrayRef<int64_t> padsShape = padsTensorType.getSizes();
-        int64_t padsRank = padsShape.size();
-        if (padsRank != 1)
-          return rewriter.notifyMatchFailure(binder.op,
-                                             "expect 1-d pad tensor");
-
-        int64_t padsSize = padsShape[0];
-        if (padsSize == Torch::kUnknownSize) {
-          // As per onnx.Pad documentation, padSize = 2*num_data_axes
-          // (if axes param not passed). Need to be updated when adding
-          // support for `axes` param.
-          auto dataOpTy = cast<Torch::ValueTensorType>(data.getType());
-          TensorType dataTensor = dataOpTy.toBuiltinTensor();
-          if (!dataTensor || !dataTensor.hasRank())
-            return rewriter.notifyMatchFailure(
-                binder.op, "pad length unknown and data operand unranked");
-          int64_t dataRank = dataTensor.getRank();
-          padsSize = 2 * dataRank;
-        }
+        // get pads (earlier versions use an attribute, newer versions use a
+        // tensor input)
+        SmallVector<Value> padsTensorValue;
+        if (binder.tensorOperandAtIndex(pads, 1)) {
+          SmallVector<int64_t> defaultPads(2 * dataRank, 0);
+          SmallVector<int64_t> padInts;
+          if (binder.s64IntegerArrayAttr(padInts, "pads", defaultPads))
+            return rewriter.notifyMatchFailure(binder.op,
+                                               "pads binder failure");
+          // opset_version 1 uses the attribute name "paddings"
+          if (padInts == defaultPads) {
+            SmallVector<int64_t> paddingsInts;
+            if (binder.s64IntegerArrayAttr(paddingsInts, "paddings",
+                                           defaultPads))
+              return rewriter.notifyMatchFailure(binder.op,
+                                                 "paddings binder failure");
+            padInts = paddingsInts;
+          }
+          for (auto p : padInts)
+            padsTensorValue.push_back(rewriter.create<Torch::ConstantIntOp>(
+                loc, rewriter.getI64IntegerAttr(p)));
+        } else {
+          // Get pads shape and rank. The pads tensor is expected to be 1-D
+          // tensor.
+          auto padsTensorType = cast<Torch::ValueTensorType>(pads.getType());
+          if (!padsTensorType || !padsTensorType.hasSizes()) {
+            return rewriter.notifyMatchFailure(binder.op,
+                                               "Expect non empty pad tensor");
+          }
+          ArrayRef<int64_t> padsShape = padsTensorType.getSizes();
+          int64_t padsRank = padsShape.size();
+          if (padsRank != 1)
+            return rewriter.notifyMatchFailure(binder.op,
+                                               "expect 1-d pad tensor");
+          if (padsShape[0] != Torch::kUnknownSize) {
+            // As per onnx.Pad documentation, padSize = 2*num_data_axes
+            // (if axes param not passed). Need to be updated when adding
+            // support for `axes` param.
+            padsSize = padsShape[0];
+          }
+
+          // Extract all the values of 1-D pad tensor and create a list of all
+          // these values as torch.pad op expects pad list.
+          Value constZero = rewriter.create<Torch::ConstantIntOp>(
+              loc, rewriter.getI64IntegerAttr(0));
+          SmallVector<int64_t> emptyShape;
+          Type padsElemType = Torch::ValueTensorType::get(
+              padsTensorType.getContext(), emptyShape,
+              padsTensorType.getOptionalDtype());
+          for (uint32_t i = 0; i < padsSize; ++i) {
+            Value index = rewriter.create<Torch::ConstantIntOp>(
+                loc, rewriter.getI64IntegerAttr(i));
+            auto select = rewriter.create<Torch::AtenSelectIntOp>(
+                loc, padsElemType, pads, constZero, index);
+            Value selectInt = rewriter.create<Torch::AtenItemOp>(
+                loc, rewriter.getType<Torch::IntType>(), select);
+            padsTensorValue.push_back(selectInt);
+          }
+        }

         Value constantValue;
-        if (binder.getNumOperands() >= 3) {
+        if (binder.getNumOperands() >= 3 && cstMode) {
           if (!binder.tensorOperandAtIndex(constantValue, 2)) {
             auto constTy =
                 dyn_cast<Torch::BaseTensorType>(constantValue.getType());
@@ -2283,38 +2326,28 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
           }
         }

-        if (!constantValue) {
+        if (!constantValue && cstMode) {
           auto dataTensorType = cast<Torch::ValueTensorType>(data.getType());
           if (isa<IntegerType>(dataTensorType.getDtype()))
             constantValue = rewriter.create<Torch::ConstantIntOp>(
                 loc, rewriter.getI64IntegerAttr(0));
-          if (isa<FloatType>(dataTensorType.getDtype()))
+          // Earlier versions used a FLOAT attribute to store the constant
+          // value. The following will pick up on any non-default value attr if
+          // provided.
+          float constantFloat;
+          if (isa<FloatType>(dataTensorType.getDtype()) &&
+              !binder.f32FloatAttr(constantFloat, "value", 0.0f))
             constantValue = rewriter.create<Torch::ConstantFloatOp>(
-                loc, rewriter.getF64FloatAttr(0.0f));
+                loc, rewriter.getF64FloatAttr(constantFloat));

           if (!constantValue)
             return rewriter.notifyMatchFailure(
                 binder.op, "expected integer or float data tensor");
         }

-        // Extract all the values of 1-D pad tensor and create a list of all
-        // these values as torch.pad op expects pad list.
-        Value constZero = rewriter.create<Torch::ConstantIntOp>(
-            loc, rewriter.getI64IntegerAttr(0));
-        SmallVector<Value> padsTensorValue;
-        SmallVector<int64_t> emptyShape;
-        Type padsElemType =
-            Torch::ValueTensorType::get(padsTensorType.getContext(), emptyShape,
-                                        padsTensorType.getOptionalDtype());
-        for (uint32_t i = 0; i < padsSize; ++i) {
-          Value index = rewriter.create<Torch::ConstantIntOp>(
-              loc, rewriter.getI64IntegerAttr(i));
-          auto select = rewriter.create<Torch::AtenSelectIntOp>(
-              loc, padsElemType, pads, constZero, index);
-          Value selectInt = rewriter.create<Torch::AtenItemOp>(
-              loc, rewriter.getType<Torch::IntType>(), select);
-          padsTensorValue.push_back(selectInt);
-        }
+        // for modes other than "constant" a value is not required
+        if (!cstMode)
+          constantValue = rewriter.create<Torch::ConstantNoneOp>(loc);

         // The torch.pad op expects a different arrangement of padding pairs for
         // each dimension as compared to the onnx.pad op. Rearrange the pad
@@ -2335,6 +2368,20 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
                 Torch::ListType::get(rewriter.getType<Torch::IntType>()),
                 padsRearrange)
             .getResult();

+        // lowering to AtenConstantPadNdOp directly allows passing any torch
+        // scalar type for the value, whereas AtenPadOp takes an optional float
+        // type.
+        if (cstMode && !isa<Torch::NoneType>(constantValue.getType())) {
+          rewriter.replaceOpWithNewOp<Torch::AtenConstantPadNdOp>(
+              binder.op, resultType, data, padsSizeList, constantValue);
+          return success();
+        }
+
+        // translate a few mismatching mode names ONNX -> Torch
+        mode = (mode == "edge") ? "replicate" : mode;
+        mode = (mode == "wrap") ? "circular" : mode;
+
         Value modeVal = rewriter.create<Torch::ConstantStrOp>(
             loc, rewriter.getStringAttr(mode));
@@ -97,8 +97,12 @@ public:

     Type newResultType = getTypeConverter()->convertType(op.getType());
     Type elementType = cast<RankedTensorType>(newResultType).getElementType();
+
+    auto dstOriginalDtype =
+        cast<Torch::ValueTensorType>(op.getType()).getDtype();
     Value castedValue =
-        convertScalarToDtype(rewriter, loc, adaptor.getValue(), elementType);
+        convertScalarToDtype(rewriter, loc, adaptor.getValue(), elementType,
+                             std::nullopt, dstOriginalDtype);

     Type padType = tensor::PadOp::inferResultType(
         cast<RankedTensorType>(self.getType()), staticLow, staticHigh);
@@ -209,26 +213,38 @@ public:
     Value one = getConstant(rewriter, loc, 1, indexType);
     Value hDimSizeMinusOne = createSub(hDimSize, one);
     Value vDimSizeMinusOne = createSub(vDimSize, one);
-    SmallVector<Value> allOneStrides(numDims, one);
+    SmallVector<Value> allOneStridesVal(numDims, one);
+    SmallVector<OpFoldResult> allOneStrides =
+        getAsOpFoldResult(allOneStridesVal);

-    SmallVector<Value> extractOffsetsLT(numDims, zero);
-    extractOffsetsLT[hDim] = zero;
-    extractOffsetsLT[vDim] = zero;
-    SmallVector<Value> extractShapeLR(numDims, one);
-    extractShapeLR[hDim] = one;
-    extractShapeLR[vDim] = vDimSize;
+    SmallVector<Value> extractOffsetsLTVal(numDims, zero);
+    extractOffsetsLTVal[hDim] = zero;
+    extractOffsetsLTVal[vDim] = zero;
+    SmallVector<OpFoldResult> extractOffsetsLT =
+        getAsOpFoldResult(extractOffsetsLTVal);
+    SmallVector<Value> extractShapeLRVal(numDims, one);
+    extractShapeLRVal[hDim] = one;
+    extractShapeLRVal[vDim] = vDimSize;
+    SmallVector<OpFoldResult> extractShapeLR =
+        getAsOpFoldResult(extractShapeLRVal);

-    SmallVector<Value> extractOffsetsRight(numDims, zero);
-    extractOffsetsRight[hDim] = hDimSizeMinusOne;
-    extractOffsetsRight[vDim] = zero;
+    SmallVector<Value> extractOffsetsRightVal(numDims, zero);
+    extractOffsetsRightVal[hDim] = hDimSizeMinusOne;
+    extractOffsetsRightVal[vDim] = zero;
+    SmallVector<OpFoldResult> extractOffsetsRight =
+        getAsOpFoldResult(extractOffsetsRightVal);

-    SmallVector<Value> extractOffsetsBottom(numDims, zero);
-    extractOffsetsBottom[hDim] = zero;
-    extractOffsetsBottom[vDim] = vDimSizeMinusOne;
+    SmallVector<Value> extractOffsetsBottomVal(numDims, zero);
+    extractOffsetsBottomVal[hDim] = zero;
+    extractOffsetsBottomVal[vDim] = vDimSizeMinusOne;
+    SmallVector<OpFoldResult> extractOffsetsBottom =
+        getAsOpFoldResult(extractOffsetsBottomVal);

-    SmallVector<Value> extractShapeTB(numDims, one);
-    extractShapeTB[hDim] = hDimSize;
-    extractShapeTB[vDim] = one;
+    SmallVector<Value> extractShapeTBVal(numDims, one);
+    extractShapeTBVal[hDim] = hDimSize;
+    extractShapeTBVal[vDim] = one;
+    SmallVector<OpFoldResult> extractShapeTB =
+        getAsOpFoldResult(extractShapeTBVal);

     SmallVector<Value> tensorsLeft;
     SmallVector<Value> tensorsRight;
@@ -240,24 +256,26 @@ public:
       Value vCenterLeftSlice = rewriter.create<tensor::ExtractSliceOp>(
           loc, input, extractOffsetsLT, extractShapeLR, allOneStrides);
       Value vLeftSlice = vCenterLeftSlice;
+      SmallVector<Value> extractIndices(numDims, zero);
       if (hasTopPadding) {
-        Value topLeftValue = rewriter.create<tensor::ExtractOp>(
-            loc, input, ValueRange{zero, zero, zero, zero});
+        Value topLeftValue =
+            rewriter.create<tensor::ExtractOp>(loc, input, extractIndices);
         // pad vCenterLeftSlice on the top
-        SmallVector<int64_t> lowPadding(4, 0);
-        SmallVector<int64_t> highPadding(4, 0);
-        lowPadding[2] = padInts[2];
+        SmallVector<int64_t> lowPadding(numDims, 0);
+        SmallVector<int64_t> highPadding(numDims, 0);
+        lowPadding[vDim] = padInts[2];
         vLeftSlice = torch_to_linalg::getPaddedTensor(
             op, rewriter, vLeftSlice, lowPadding, highPadding, topLeftValue);
       }
       if (hasBottomPadding) {
-        Value bottomLeftValue = rewriter.create<tensor::ExtractOp>(
-            loc, input, ValueRange{zero, zero, vDimSizeMinusOne, zero});
+        extractIndices[vDim] = vDimSizeMinusOne;
+        Value bottomLeftValue =
+            rewriter.create<tensor::ExtractOp>(loc, input, extractIndices);

         // pad vLeftSlice at the bottom
-        SmallVector<int64_t> lowPadding(4, 0);
-        SmallVector<int64_t> highPadding(4, 0);
-        highPadding[2] = padInts[3];
+        SmallVector<int64_t> lowPadding(numDims, 0);
+        SmallVector<int64_t> highPadding(numDims, 0);
+        highPadding[vDim] = padInts[3];
         vLeftSlice = torch_to_linalg::getPaddedTensor(
             op, rewriter, vLeftSlice, lowPadding, highPadding, bottomLeftValue);
       }
@@ -265,7 +283,7 @@ public:
         tensorsLeft.push_back(vLeftSlice);
       }
       Value leftPadTile =
-          rewriter.create<tensor::ConcatOp>(loc, 3, tensorsLeft);
+          rewriter.create<tensor::ConcatOp>(loc, hDim, tensorsLeft);
       tensorsRes.push_back(leftPadTile);
     }
     if (hasTopPadding) {
@@ -283,33 +301,35 @@ public:
          tensorsCenter.push_back(bottomHcenterSlice);
        }
      }
-      centerTile = rewriter.create<tensor::ConcatOp>(loc, 2, tensorsCenter);
+      centerTile = rewriter.create<tensor::ConcatOp>(loc, vDim, tensorsCenter);
      tensorsRes.push_back(centerTile);

     if (hasRightPadding) {
       Value vCenterRightSlice = rewriter.create<tensor::ExtractSliceOp>(
           loc, input, extractOffsetsRight, extractShapeLR, allOneStrides);
       Value vRightSlice = vCenterRightSlice;
+      SmallVector<Value> extractIndices(numDims, zero);
+      extractIndices[hDim] = hDimSizeMinusOne;
       if (hasTopPadding) {
         Value topRightValue = rewriter.create<tensor::ExtractOp>(
             loc, input, ValueRange{zero, zero, zero, hDimSizeMinusOne});

         // pad vCenterRightSlice on the top
-        SmallVector<int64_t> lowPadding(4, 0);
-        SmallVector<int64_t> highPadding(4, 0);
-        lowPadding[2] = padInts[2];
+        SmallVector<int64_t> lowPadding(numDims, 0);
+        SmallVector<int64_t> highPadding(numDims, 0);
+        lowPadding[vDim] = padInts[2];
         vRightSlice = torch_to_linalg::getPaddedTensor(
             op, rewriter, vRightSlice, lowPadding, highPadding, topRightValue);
       }
       if (hasBottomPadding) {
-        Value bottomRightValue = rewriter.create<tensor::ExtractOp>(
-            loc, input,
-            ValueRange{zero, zero, vDimSizeMinusOne, hDimSizeMinusOne});
+        extractIndices[vDim] = vDimSizeMinusOne;
+        Value bottomRightValue =
+            rewriter.create<tensor::ExtractOp>(loc, input, extractIndices);

         // Pad vCenterRightSlice or vRightTopPaddedSlice at the bottom.
-        SmallVector<int64_t> lowPadding(4, 0);
-        SmallVector<int64_t> highPadding(4, 0);
-        highPadding[2] = padInts[3];
+        SmallVector<int64_t> lowPadding(numDims, 0);
+        SmallVector<int64_t> highPadding(numDims, 0);
+        highPadding[vDim] = padInts[3];
         vRightSlice = torch_to_linalg::getPaddedTensor(
             op, rewriter, vRightSlice, lowPadding, highPadding,
             bottomRightValue);
@@ -318,10 +338,10 @@ public:
         tensorsRight.push_back(vRightSlice);
       }
       Value rightPadTile =
-          rewriter.create<tensor::ConcatOp>(loc, 3, tensorsRight);
+          rewriter.create<tensor::ConcatOp>(loc, hDim, tensorsRight);
       tensorsRes.push_back(rightPadTile);
     }
-    Value resTensor = rewriter.create<tensor::ConcatOp>(loc, 3, tensorsRes);
+    Value resTensor = rewriter.create<tensor::ConcatOp>(loc, hDim, tensorsRes);
     Type newResultType = getTypeConverter()->convertType(op.getType());
     rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, resTensor);
     return success();
@@ -6379,17 +6379,91 @@ class DecomposeAtenPadOp : public OpRewritePattern<AtenPadOp> {
   using OpRewritePattern::OpRewritePattern;
   LogicalResult matchAndRewrite(AtenPadOp op,
                                 PatternRewriter &rewriter) const override {
+    std::string mode;
+    if (!matchPattern(op.getMode(), m_TorchConstantStr(mode)))
+      return rewriter.notifyMatchFailure(op, "mode must be a constant string");

-    Value value = op.getValue();
-    if (isa<Torch::OptionalType>(value.getType()))
-      return rewriter.notifyMatchFailure(op, "optional type not supported");
-    if (isa<Torch::NoneType>(value.getType()))
-      value = rewriter.create<Torch::ConstantFloatOp>(
-          op.getLoc(), rewriter.getF64FloatAttr(0));
+    if (mode == "constant") {
+      Value value = op.getValue();
+      if (isa<Torch::OptionalType>(value.getType()))
+        return rewriter.notifyMatchFailure(op, "optional type not supported");
+      if (isa<Torch::NoneType>(value.getType()))
+        value = rewriter.create<Torch::ConstantFloatOp>(
+            op.getLoc(), rewriter.getF64FloatAttr(0));

       rewriter.replaceOpWithNewOp<AtenConstantPadNdOp>(
           op, op.getType(), op.getSelf(), op.getPad(), value);
       return success();
+    }
+
+    SmallVector<Value> padValues;
+    if (!getListConstructElements(op.getPad(), padValues))
+      return failure();
+    SmallVector<int64_t> padInts;
+    Value usefulPads = op.getPad();
+    uint64_t usefulPadIndexEnd = padValues.size();
+
+    // try to reduce the number of padding dims if possible
+    if (matchPattern(op.getPad(), m_TorchListOfConstantInts(padInts))) {
+      if ((padInts.size() % 2) == 1)
+        return rewriter.notifyMatchFailure(op,
+                                           "expected an even number of pads");
+
+      for (uint64_t i = padInts.size() - 1; i > 0; i -= 2) {
+        if (padInts[i] != 0 || padInts[i - 1] != 0)
+          break;
+        usefulPadIndexEnd = i - 1;
+      }
+      if (usefulPadIndexEnd == 0) {
+        rewriter.replaceOp(op, op.getSelf());
+        return success();
+      }
+    }
+
+    // we don't have support for 1-D replicate pad, so pass it as 2d if
+    // possible.
+    // TODO: add support for AtenReplicatePad1dOp and remove this.
+    if (mode == "replicate" && usefulPadIndexEnd == 2 && padValues.size() >= 4)
+      usefulPadIndexEnd = 4;
+
+    // make a new list of padding ints if dimensionality reduction can be
+    // performed
+    if (usefulPadIndexEnd < padValues.size()) {
+      ArrayRef<Value> usefulPadValues(padValues.begin(),
+                                      padValues.begin() + usefulPadIndexEnd);
+      usefulPads = rewriter.create<PrimListConstructOp>(
+          op.getLoc(),
+          rewriter.getType<Torch::ListType>(rewriter.getType<Torch::IntType>()),
+          usefulPadValues);
+    }
+
+    uint64_t numPadDims = usefulPadIndexEnd / 2;
+
+    if (mode == "reflect") {
+      // only support for relectionpad 1d and 2d
+      if (numPadDims == 2) {
+        rewriter.replaceOpWithNewOp<AtenReflectionPad2dOp>(
+            op, op.getType(), op.getSelf(), usefulPads);
+        return success();
+      }
+      if (numPadDims == 1) {
+        rewriter.replaceOpWithNewOp<AtenReflectionPad1dOp>(
+            op, op.getType(), op.getSelf(), usefulPads);
+        return success();
+      }
+      return failure();
+    }
+
+    if (mode == "replicate") {
+      // only support for replication pad 2d
+      if (numPadDims != 2)
+        return failure();
+      rewriter.replaceOpWithNewOp<AtenReplicationPad2dOp>(
+          op, op.getType(), op.getSelf(), usefulPads);
+      return success();
+    }
+
+    return rewriter.notifyMatchFailure(op, "unsupported mode: " + mode);
   }
 };
 } // namespace
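To make the trimming in DecomposeAtenPadOp above concrete: trailing (begin, end) pairs that are both zero are dropped so the pad can be matched to the lowest-rank reflection/replication pad op. A standalone sketch of the same computation (an illustrative helper with assumed names, not code from the patch):

#include <cstdint>
#include <vector>

// Returns the length of the useful prefix of a torch-style pad list, i.e. the
// list with trailing (begin, end) pairs that are both zero removed.
size_t usefulPadLength(const std::vector<int64_t> &padInts) {
  size_t end = padInts.size();
  for (size_t i = padInts.size(); i >= 2; i -= 2) {
    if (padInts[i - 1] != 0 || padInts[i - 2] != 0)
      break;
    end = i - 2;
  }
  return end;
}

// usefulPadLength({1, 1, 0, 0, 0, 0}) == 2 -> one padded dim, so a "reflect"
// pad maps to AtenReflectionPad1dOp; usefulPadLength({0, 0, 0, 0}) == 0 -> the
// pad is a no-op and the input is returned unchanged.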
@@ -40,7 +40,8 @@ bool isQCommutingOp(mlir::Operation *op) {
   // if adding a new commuting op here, be sure to add a
   // RemoveUnused pattern for that op to clean up afterwards
   return llvm::isa<AtenTransposeIntOp, AtenReshapeOp, AtenSliceTensorOp,
-                   PrimsCollapseOp, AtenViewOp>(op);
+                   PrimsCollapseOp, AtenViewOp, AtenPadOp, AtenConstantPadNdOp>(
+      op);
 }

 // The following conversion takes patterns of the form [op0 -> MPTQT -> dequant
@@ -65,7 +66,7 @@ public:
     for (unsigned i : QuantInfo<SrcOp>::operandsToQuantize) {
       Value operand = operands[i];
       std::stack<mlir::Operation *> commutingOpStack;
-      Value dequantOpd, MPTQTOpd;
+      Value dequantOpd, MPTQTOpd, scale, zeroPoint;
       for (unsigned k = 0; k < depth + 1; k++) {
         auto currOp = operand.getDefiningOp();
         // Case 0 : currOp is a nullptr (e.g., operand is a block argument)
@@ -84,6 +85,8 @@ public:
           auto MPTQTOp =
               dequantOpd.getDefiningOp<Aten_MakePerTensorQuantizedTensorOp>();
           MPTQTOpd = MPTQTOp.getOperand(0);
+          scale = MPTQTOp.getOperand(1);
+          zeroPoint = MPTQTOp.getOperand(2);
         }
         // either a dequant was found or chain broken, so break loop
         break;
@@ -107,6 +110,47 @@ public:
       commutingOpStack.pop();
       llvm::SmallVector<Value> currOperands(currOp->getOperands());
       currOperands[0] = oldOpd;
+      // pad ops aren't quite commuting, so we include some extra logic to
+      // quantize the padding value
+      if (isa<Torch::AtenPadOp, Torch::AtenConstantPadNdOp>(currOp)) {
+        Value floatPadValue = currOperands.back();
+        Value quantPadValue;
+        if (isa<Torch::NoneType>(floatPadValue.getType()))
+          quantPadValue = rewriter.create<AtenFloatScalarOp>(loc, zeroPoint);
+        else {
+          floatPadValue =
+              rewriter.create<AtenFloatScalarOp>(loc, floatPadValue);
+          quantPadValue = rewriter.create<Torch::AtenDivFloatOp>(
+              loc, floatPadValue, scale);
+          quantPadValue = rewriter.create<Torch::AtenAddFloatIntOp>(
+              loc, quantPadValue, zeroPoint);
+        }
+        // clamp pad value to qint range
+        if (auto intType = dyn_cast<mlir::IntegerType>(intDType)) {
+          bool isSigned = intType.isSignedInteger();
+          int64_t width = intType.getWidth();
+          assert(width < 64 &&
+                 "quantized int bitwidth should be less than 64");
+          int64_t minInt = isSigned ? -(1 << (width - 1)) : 0;
+          int64_t maxInt = isSigned ? -minInt - 1 : ((1 << width) - 1);
+          Value minQValueFloat = rewriter.create<ConstantFloatOp>(
+              loc, rewriter.getF64FloatAttr(minInt));
+          Value maxQValueFloat = rewriter.create<ConstantFloatOp>(
+              loc, rewriter.getF64FloatAttr(maxInt));
+          SmallVector<int64_t> emptyShape;
+          auto floatTensorType = rewriter.getType<Torch::ValueTensorType>(
+              emptyShape, rewriter.getF64Type());
+          Value quantPadValueTensor = createRank0Tensor(
+              rewriter, loc, floatTensorType, quantPadValue);
+          Value clampedTensor = rewriter.create<Torch::AtenClampOp>(
+              loc, floatTensorType, quantPadValueTensor, minQValueFloat,
+              maxQValueFloat);
+          quantPadValue = rewriter.create<Torch::AtenItemOp>(
+              loc, rewriter.getType<Torch::FloatType>(), clampedTensor);
+        }
+        // quantPadValue is a float, but will get converted/truncated
+        currOperands.back() = quantPadValue;
+      }
       // get new result type
       auto oldType = cast<ValueTensorType>(currOp->getResultTypes()[0]);
       auto intType =
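The pad-value handling added above amounts to mapping the floating-point pad value into the quantized domain as value / scale + zero_point, then clamping to the range of the quantized integer type (rounding or truncation happens later when the scalar is materialized). A standalone sketch of that arithmetic (illustrative only, plain C++ rather than rewriter calls):

#include <algorithm>
#include <cstdint>

// Quantizes a constant pad value the same way the fused pad op's value is
// rewritten: divide by scale, add the zero point, clamp to the qint range.
double quantizePadValue(double padValue, double scale, int64_t zeroPoint,
                        int64_t width, bool isSigned) {
  double q = padValue / scale + static_cast<double>(zeroPoint);
  int64_t minInt = isSigned ? -(1LL << (width - 1)) : 0;
  int64_t maxInt = isSigned ? -minInt - 1 : ((1LL << width) - 1);
  return std::clamp(q, static_cast<double>(minInt),
                    static_cast<double>(maxInt));
}

// e.g. a pad value of 3.5 with scale 0.5 and zero point 1 becomes
// 3.5 / 0.5 + 1 = 8.0, matching the %[[padVal]] = 8.000000e+00 constant
// checked in the mm_pad_commute test below.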
@@ -374,7 +418,8 @@ public:
         RemoveUnused<Aten_MakePerTensorQuantizedTensorOp>,
         RemoveUnused<AtenTransposeIntOp>, RemoveUnused<AtenSliceTensorOp>,
         RemoveUnused<AtenReshapeOp>, RemoveUnused<PrimsCollapseOp>,
-        RemoveUnused<AtenViewOp>,
+        RemoveUnused<AtenViewOp>, RemoveUnused<AtenPadOp>,
+        RemoveUnused<AtenConstantPadNdOp>,
         QuantizeOperandsPastCommutingOps<AtenConvolutionOp, 5>,
         QuantizeOperandsPastCommutingOps<AtenReluOp, 0>,
         QuantizeOperandsPastCommutingOps<AtenMatmulOp, 2>,
@@ -1014,6 +1014,38 @@ func.func @test_conv_with_strides_padding(%arg0: !torch.vtensor<[1,1,7,5],f32>,

 // -----

+// CHECK-LABEL: @test_conv_with_asymmetric_padding
+func.func @test_conv_with_asymmetric_padding(%arg0: !torch.vtensor<[1,1,7,5],f32>, %arg1: !torch.vtensor<[1,1,3,3],f32>) -> !torch.vtensor<[1,1,4,3],f32> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
+  // CHECK: %[[int0:.*]] = torch.constant.int 0
+  // CHECK: %[[int2:.*]] = torch.constant.int 2
+  // CHECK: %[[int0_0:.*]] = torch.constant.int 0
+  // CHECK: %[[int2_1:.*]] = torch.constant.int 2
+  // CHECK: %[[int0_2:.*]] = torch.constant.int 0
+  // CHECK: %[[int0_3:.*]] = torch.constant.int 0
+  // CHECK: %[[FakePADS:.*]] = torch.prim.ListConstruct %[[int0_0]], %[[int0_3]] : (!torch.int, !torch.int) -> !torch.list<int>
+  // CHECK: %[[OGPADS:.*]] = torch.prim.ListConstruct %[[int0]], %[[int2]], %[[int2_1]], %[[int0_2]] : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
+  // CHECK: %[[str:.*]] = torch.constant.str "constant"
+  // CHECK: %[[float0:.*]] = torch.constant.float 0.000
+  // CHECK: %[[PrePad:.*]] = torch.aten.pad %arg0, %[[OGPADS]], %[[str]], %[[float0]] : !torch.vtensor<[1,1,7,5],f32>, !torch.list<int>, !torch.str, !torch.float -> !torch.vtensor<[1,1,9,7],f32>
+  // CHECK: %[[C1_1:.*]] = torch.constant.int 1
+  // CHECK: %[[C1_2:.*]] = torch.constant.int 1
+  // CHECK: %[[C2:.*]] = torch.constant.int 2
+  // CHECK: %[[C2_0:.*]] = torch.constant.int 2
+  // CHECK: %[[C0:.*]] = torch.constant.int 0
+  // CHECK: %[[DILATIONS:.*]] = torch.prim.ListConstruct %[[C1_1]], %[[C1_2]] : (!torch.int, !torch.int) -> !torch.list<int>
+  // CHECK: %[[STRIDE:.*]] = torch.prim.ListConstruct %[[C2]], %[[C2_0]] : (!torch.int, !torch.int) -> !torch.list<int>
+  // CHECK: %[[OUTPUT_PADDING:.*]] = torch.prim.ListConstruct %[[C0]], %[[C0]] : (!torch.int, !torch.int) -> !torch.list<int>
+  // CHECK: %[[TRANSPOSED:.*]] = torch.constant.bool false
+  // CHECK: %[[BIAS:.*]] = torch.constant.none
+  // CHECK: %[[GROUPS:.*]] = torch.constant.int 1
+  // CHECK: %[[Conv:.*]] = torch.aten.convolution %[[PrePad]], %arg1, %[[BIAS]], %[[STRIDE]], %[[FakePADS]], %[[DILATIONS]], %[[TRANSPOSED]], %[[OUTPUT_PADDING]], %[[GROUPS]] : !torch.vtensor<[1,1,9,7],f32>, !torch.vtensor<[1,1,3,3],f32>, !torch.none, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int -> !torch.vtensor<[1,1,4,3],f32>
+  // CHECK: return %[[Conv]]
+  %0 = torch.operator "onnx.Conv"(%arg0, %arg1) {torch.onnx.kernel_shape = [3 : si64, 3 : si64], torch.onnx.pads = [2 : si64, 0 : si64, 0 : si64, 2 : si64], torch.onnx.strides = [2 : si64, 2 : si64]} : (!torch.vtensor<[1,1,7,5],f32>, !torch.vtensor<[1,1,3,3],f32>) -> !torch.vtensor<[1,1,4,3],f32>
+  return %0 : !torch.vtensor<[1,1,4,3],f32>
+}
+
+// -----
+
 // CHECK-LABEL: @test_conv_with_bias_strides_padding
 func.func @test_conv_with_bias_strides_padding(%arg0: !torch.vtensor<[?,?,224,224],f32>, %arg1: !torch.vtensor<[64,3,7,7],f32>, %arg2: !torch.vtensor<[64],f32>) -> !torch.vtensor<[?,64,112,112],f32> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
   // CHECK: %[[C3:.*]] = torch.constant.int 3
@@ -815,32 +815,40 @@ func.func @test_grid_sampler02(%arg0: !torch.vtensor<[5,10,10,4],f32>, %arg1: !t

 // -----

-// CHECK-LABEL: @test_grid_sampler03
-// CHECK: %[[INT1:.*]] = torch.constant.int 1
-// CHECK: %[[INT0:.*]] = torch.constant.int 0
-// CHECK: %[[B0:.*]] = torch.constant.bool true
-// CHECK: %[[A0:.*]] = torch.aten.grid_sampler %arg0, %arg1, %[[INT1]], %[[INT0]], %[[B0]] : !torch.vtensor<[5,10,10,4],f32>
-func.func @test_grid_sampler03(%arg0: !torch.vtensor<[5,10,10,4],f32>, %arg1: !torch.vtensor<[5,7,8,2],f32>) -> !torch.vtensor<[?,?,?,?],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 17 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
-  %0 = torch.operator "onnx.GridSample" (%arg0, %arg1) {torch.onnx.align_corners = 1 : si64, torch.onnx.mode = "nearest", torch.onnx.padding_mode = "zeros"} : (!torch.vtensor<[5,10,10,4],f32>, !torch.vtensor<[5,7,8,2],f32>) -> !torch.vtensor<[?,?,?,?],f32>
-  return %0 : !torch.vtensor<[?,?,?,?],f32>
+// CHECK-LABEL: func.func @test_oldest_pad
+func.func @test_oldest_pad(%arg0: !torch.vtensor<[3,4],f32>) -> !torch.vtensor<[5,4],f32> attributes {torch.onnx_meta.opset_version = 1 : si64} {
+  // CHECK: %[[int0:.*]] = torch.constant.int 0
+  // CHECK: %[[int0_0:.*]] = torch.constant.int 0
+  // CHECK: %[[int2:.*]] = torch.constant.int 2
+  // CHECK: %[[int0_1:.*]] = torch.constant.int 0
+  // CHECK: %[[float0:.*]] = torch.constant.float 0.000000e+00
+  // CHECK: %[[list:.*]] = torch.prim.ListConstruct %[[int0_0]], %[[int0_1]], %[[int0]], %[[int2]] : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
+  // CHECK: %[[pad:.*]] = torch.aten.constant_pad_nd %arg0, %[[list]], %[[float0]] : !torch.vtensor<[3,4],f32>, !torch.list<int>, !torch.float -> !torch.vtensor<[5,4],f32>
+  // CHECK: return %[[pad]] : !torch.vtensor<[5,4],f32>
+  %0 = torch.operator "onnx.Pad"(%arg0) {torch.onnx.mode = "constant", torch.onnx.paddings = [0 : si64, 0 : si64, 2 : si64, 0 : si64], torch.onnx.value = 0.000000e+00 : f32} : (!torch.vtensor<[3,4],f32>) -> !torch.vtensor<[5,4],f32>
+  return %0 : !torch.vtensor<[5,4],f32>
 }

 // -----

-// CHECK-LABEL: func.func @test_less_or_equal
-func.func @test_less_or_equal(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],i1> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 16 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
-  // CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: !torch.vtensor<[3,4,5],f32>
-  // CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: !torch.vtensor<[3,4,5],f32>
-  // CHECK: torch.aten.le.Tensor %[[ARG0]], %[[ARG1]] : !torch.vtensor<[3,4,5],f32>, !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],i1>
-  %0 = torch.operator "onnx.LessOrEqual"(%arg0, %arg1) : (!torch.vtensor<[3,4,5],f32>, !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],i1>
-  return %0 : !torch.vtensor<[3,4,5],i1>
+// CHECK-LABEL: func.func @test_old_pad
+func.func @test_old_pad(%arg0: !torch.vtensor<[3,4],f32>) -> !torch.vtensor<[5,4],f32> attributes {torch.onnx_meta.opset_version = 11 : si64} {
+  // CHECK: %[[int0:.*]] = torch.constant.int 0
+  // CHECK: %[[int0_0:.*]] = torch.constant.int 0
+  // CHECK: %[[int2:.*]] = torch.constant.int 2
+  // CHECK: %[[int0_1:.*]] = torch.constant.int 0
+  // CHECK: %[[float0:.*]] = torch.constant.float 0.000000e+00
+  // CHECK: %[[list:.*]] = torch.prim.ListConstruct %[[int0_0]], %[[int0_1]], %[[int0]], %[[int2]] : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
+  // CHECK: %[[pad:.*]] = torch.aten.constant_pad_nd %arg0, %[[list]], %[[float0]] : !torch.vtensor<[3,4],f32>, !torch.list<int>, !torch.float -> !torch.vtensor<[5,4],f32>
+  // CHECK: return %[[pad]] : !torch.vtensor<[5,4],f32>
+  %0 = torch.operator "onnx.Pad"(%arg0) {torch.onnx.mode = "constant", torch.onnx.pads = [0 : si64, 0 : si64, 2 : si64, 0 : si64], torch.onnx.value = 0.000000e+00 : f32} : (!torch.vtensor<[3,4],f32>) -> !torch.vtensor<[5,4],f32>
+  return %0 : !torch.vtensor<[5,4],f32>
 }

 // -----

 // CHECK-LABEL: func.func @test_pad
 func.func @test_pad(%arg0: !torch.vtensor<[3,4],f32>, %arg1: !torch.vtensor<[4], si64>, %arg2: !torch.vtensor<[], f32>) -> !torch.vtensor<[5,4],f32> attributes {torch.onnx_meta.opset_version = 19 : si64} {
-  // CHECK: %[[VAL:.+]] = torch.aten.item %arg2 : !torch.vtensor<[],f32> -> !torch.float
   // CHECK: %[[INT0:.+]] = torch.constant.int 0
   // CHECK: %[[INT0_0:.+]] = torch.constant.int 0
   // CHECK: %[[SELECT_0:.+]] = torch.aten.select.int %arg1, %[[INT0]], %[[INT0_0]] : !torch.vtensor<[4],si64>, !torch.int, !torch.int -> !torch.vtensor<[],si64>
|
||||||
// CHECK: %[[INT3:.+]] = torch.constant.int 3
|
// CHECK: %[[INT3:.+]] = torch.constant.int 3
|
||||||
// CHECK: %[[SELECT_3:.+]] = torch.aten.select.int %arg1, %[[INT0]], %[[INT3]] : !torch.vtensor<[4],si64>, !torch.int, !torch.int -> !torch.vtensor<[],si64>
|
// CHECK: %[[SELECT_3:.+]] = torch.aten.select.int %arg1, %[[INT0]], %[[INT3]] : !torch.vtensor<[4],si64>, !torch.int, !torch.int -> !torch.vtensor<[],si64>
|
||||||
// CHECK: %[[ITEM_3:.+]] = torch.aten.item %[[SELECT_3]] : !torch.vtensor<[],si64> -> !torch.int
|
// CHECK: %[[ITEM_3:.+]] = torch.aten.item %[[SELECT_3]] : !torch.vtensor<[],si64> -> !torch.int
|
||||||
|
// CHECK: %[[VAL:.+]] = torch.aten.item %arg2 : !torch.vtensor<[],f32> -> !torch.float
|
||||||
// CHECK: %[[LIST:.+]] = torch.prim.ListConstruct %[[ITEM_1]], %[[ITEM_3]], %[[ITEM_0]], %[[ITEM_2]] : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
|
// CHECK: %[[LIST:.+]] = torch.prim.ListConstruct %[[ITEM_1]], %[[ITEM_3]], %[[ITEM_0]], %[[ITEM_2]] : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
|
||||||
// CHECK: %[[STR:.+]] = torch.constant.str "constant"
|
// CHECK: %[[PAD:.+]] = torch.aten.constant_pad_nd %arg0, %[[LIST]], %[[VAL]] : !torch.vtensor<[3,4],f32>, !torch.list<int>, !torch.float -> !torch.vtensor<[5,4],f32>
|
||||||
// CHECK: %[[PAD:.+]] = torch.aten.pad %arg0, %[[LIST]], %[[STR]], %[[VAL]] : !torch.vtensor<[3,4],f32>, !torch.list<int>, !torch.str, !torch.float -> !torch.vtensor<[5,4],f32>
|
|
||||||
// CHECK: return %[[PAD]] : !torch.vtensor<[5,4],f32>
|
// CHECK: return %[[PAD]] : !torch.vtensor<[5,4],f32>
|
||||||
%0 = torch.operator "onnx.Pad"(%arg0, %arg1, %arg2) {torch.onnx.mode = "constant"} : (!torch.vtensor<[3,4],f32>, !torch.vtensor<[4], si64>, !torch.vtensor<[], f32>) -> !torch.vtensor<[5,4],f32>
|
%0 = torch.operator "onnx.Pad"(%arg0, %arg1, %arg2) {torch.onnx.mode = "constant"} : (!torch.vtensor<[3,4],f32>, !torch.vtensor<[4], si64>, !torch.vtensor<[], f32>) -> !torch.vtensor<[5,4],f32>
|
||||||
return %0 : !torch.vtensor<[5,4],f32>
|
return %0 : !torch.vtensor<[5,4],f32>
|
||||||
|
@@ -864,12 +872,36 @@ func.func @test_pad(%arg0: !torch.vtensor<[3,4],f32>, %arg1: !torch.vtensor<[4],

 // -----

+// CHECK-LABEL: func.func @test_i32pad
+func.func @test_i32pad(%arg0: !torch.vtensor<[3,4],si32>, %arg1: !torch.vtensor<[4], si64>, %arg2: !torch.vtensor<[], si32>) -> !torch.vtensor<[5,4],si32> attributes {torch.onnx_meta.opset_version = 19 : si64} {
+  // CHECK: %[[INT0:.+]] = torch.constant.int 0
+  // CHECK: %[[INT0_0:.+]] = torch.constant.int 0
+  // CHECK: %[[SELECT_0:.+]] = torch.aten.select.int %arg1, %[[INT0]], %[[INT0_0]] : !torch.vtensor<[4],si64>, !torch.int, !torch.int -> !torch.vtensor<[],si64>
+  // CHECK: %[[ITEM_0:.+]] = torch.aten.item %[[SELECT_0]] : !torch.vtensor<[],si64> -> !torch.int
+  // CHECK: %[[INT1:.+]] = torch.constant.int 1
+  // CHECK: %[[SELECT_1:.+]] = torch.aten.select.int %arg1, %[[INT0]], %[[INT1]] : !torch.vtensor<[4],si64>, !torch.int, !torch.int -> !torch.vtensor<[],si64>
+  // CHECK: %[[ITEM_1:.+]] = torch.aten.item %[[SELECT_1]] : !torch.vtensor<[],si64> -> !torch.int
+  // CHECK: %[[INT2:.+]] = torch.constant.int 2
+  // CHECK: %[[SELECT_2:.+]] = torch.aten.select.int %arg1, %[[INT0]], %[[INT2]] : !torch.vtensor<[4],si64>, !torch.int, !torch.int -> !torch.vtensor<[],si64>
+  // CHECK: %[[ITEM_2:.+]] = torch.aten.item %[[SELECT_2]] : !torch.vtensor<[],si64> -> !torch.int
+  // CHECK: %[[INT3:.+]] = torch.constant.int 3
+  // CHECK: %[[SELECT_3:.+]] = torch.aten.select.int %arg1, %[[INT0]], %[[INT3]] : !torch.vtensor<[4],si64>, !torch.int, !torch.int -> !torch.vtensor<[],si64>
+  // CHECK: %[[ITEM_3:.+]] = torch.aten.item %[[SELECT_3]] : !torch.vtensor<[],si64> -> !torch.int
+  // CHECK: %[[VAL:.+]] = torch.aten.item %arg2 : !torch.vtensor<[],si32> -> !torch.int
+  // CHECK: %[[LIST:.+]] = torch.prim.ListConstruct %[[ITEM_1]], %[[ITEM_3]], %[[ITEM_0]], %[[ITEM_2]] : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
+  // CHECK: %[[PAD:.+]] = torch.aten.constant_pad_nd %arg0, %[[LIST]], %[[VAL]] : !torch.vtensor<[3,4],si32>, !torch.list<int>, !torch.int -> !torch.vtensor<[5,4],si32>
+  // CHECK: return %[[PAD]] : !torch.vtensor<[5,4],si32>
+  %0 = torch.operator "onnx.Pad"(%arg0, %arg1, %arg2) {torch.onnx.mode = "constant"} : (!torch.vtensor<[3,4],si32>, !torch.vtensor<[4], si64>, !torch.vtensor<[], si32>) -> !torch.vtensor<[5,4],si32>
+  return %0 : !torch.vtensor<[5,4],si32>
+}
+
+// -----
+
 // CHECK-LABEL: @test_pad_optional_constant
 // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[3,4],f32>
 // CHECK-SAME: %[[ARG1:.*]]: !torch.vtensor<[4],si64>
 // CHECK: %[[VAL:.+]] = torch.constant.float 0
-// CHECK: %[[CONST_STR:.*]] = torch.constant.str "constant"
-// CHECK: torch.aten.pad %[[ARG0]], %{{.*}}, %[[CONST_STR]], %[[VAL]] : !torch.vtensor<[3,4],f32>, !torch.list<int>, !torch.str, !torch.float -> !torch.vtensor<[5,4],f32>
+// CHECK: torch.aten.constant_pad_nd %[[ARG0]], %{{.*}}, %[[VAL]] : !torch.vtensor<[3,4],f32>, !torch.list<int>, !torch.float -> !torch.vtensor<[5,4],f32>

 func.func @test_pad_optional_constant(%arg0: !torch.vtensor<[3,4],f32>, %arg1: !torch.vtensor<[4], si64>) -> !torch.vtensor<[5,4],f32> attributes {torch.onnx_meta.opset_version = 19 : si64} {
   %0 = torch.operator "onnx.Pad"(%arg0, %arg1) {torch.onnx.mode = "constant"} : (!torch.vtensor<[3,4],f32>, !torch.vtensor<[4], si64>) -> !torch.vtensor<[5,4],f32>
@@ -878,6 +910,34 @@ func.func @test_pad_optional_constant(%arg0: !torch.vtensor<[3,4],f32>, %arg1: !

 // -----

+// CHECK-LABEL: @test_pad_wrap
+// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[3,4],f32>
+// CHECK-SAME: %[[ARG1:.*]]: !torch.vtensor<[4],si64>
+// CHECK: %[[VAL:.+]] = torch.constant.none
+// CHECK: %[[STR:.+]] = torch.constant.str "circular"
+// CHECK: torch.aten.pad %[[ARG0]], %{{.*}}, %[[STR]], %[[VAL]] : !torch.vtensor<[3,4],f32>, !torch.list<int>, !torch.str, !torch.none -> !torch.vtensor<[5,4],f32>
+
+func.func @test_pad_wrap(%arg0: !torch.vtensor<[3,4],f32>, %arg1: !torch.vtensor<[4], si64>) -> !torch.vtensor<[5,4],f32> attributes {torch.onnx_meta.opset_version = 19 : si64} {
+  %0 = torch.operator "onnx.Pad"(%arg0, %arg1) {torch.onnx.mode = "wrap"} : (!torch.vtensor<[3,4],f32>, !torch.vtensor<[4], si64>) -> !torch.vtensor<[5,4],f32>
+  return %0 : !torch.vtensor<[5,4],f32>
+}
+
+// -----
+
+// CHECK-LABEL: @test_pad_edge
+// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[3,4],f32>
+// CHECK-SAME: %[[ARG1:.*]]: !torch.vtensor<[4],si64>
+// CHECK: %[[VAL:.+]] = torch.constant.none
+// CHECK: %[[STR:.+]] = torch.constant.str "replicate"
+// CHECK: torch.aten.pad %[[ARG0]], %{{.*}}, %[[STR]], %[[VAL]] : !torch.vtensor<[3,4],f32>, !torch.list<int>, !torch.str, !torch.none -> !torch.vtensor<[5,4],f32>
+
+func.func @test_pad_edge(%arg0: !torch.vtensor<[3,4],f32>, %arg1: !torch.vtensor<[4], si64>) -> !torch.vtensor<[5,4],f32> attributes {torch.onnx_meta.opset_version = 19 : si64} {
+  %0 = torch.operator "onnx.Pad"(%arg0, %arg1) {torch.onnx.mode = "edge"} : (!torch.vtensor<[3,4],f32>, !torch.vtensor<[4], si64>) -> !torch.vtensor<[5,4],f32>
+  return %0 : !torch.vtensor<[5,4],f32>
+}
+
+// -----
+
 // CHECK-LABEL: func.func @test_pow
 func.func @test_pow(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 15 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
   // CHECK: torch.aten.pow.Tensor_Tensor %arg0, %arg1 : !torch.vtensor<[3,4,5],f32>, !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],f32>
@@ -82,6 +82,48 @@ func.func @matmul_commuting(%arg0: !torch.vtensor<[2,128,32,32],si8>) -> !torch.

 // -----

+// CHECK-LABEL: func.func @mm_pad_commute
+func.func @mm_pad_commute(%arg0: !torch.vtensor<[8,8],si8>, %arg1: !torch.vtensor<[11,4],si8>) -> !torch.vtensor<[9,4],f32> {
+  // CHECK-DAG: %[[cstQuart:.*]] = torch.constant.float 2.500000e-01
+  // CHECK-DAG: %[[int7:.*]] = torch.constant.int 7
+  // CHECK-DAG: %[[none:.*]] = torch.constant.none
+  // CHECK-DAG: %[[qMax:.*]] = torch.constant.float 1.270000e+02
+  // CHECK-DAG: %[[qMin:.*]] = torch.constant.float -1.280000e+02
+  // CHECK-DAG: %[[padVal:.*]] = torch.constant.float 8.000000e+00
+  // CHECK-DAG: %[[str:.*]] = torch.constant.str "constant"
+  // CHECK-DAG: %[[cstHalf:.*]] = torch.constant.float 5.000000e-01
+  // CHECK-DAG: %[[int0:.*]] = torch.constant.int 0
+  // CHECK-DAG: %[[int1:.*]] = torch.constant.int 1
+  // CHECK-DAG: %[[int2:.*]] = torch.constant.int 2
+  // CHECK: %[[PadList:.*]] = torch.prim.ListConstruct %[[int1]], %[[int2]], %[[int0]], %[[int1]] : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
+  // CHECK: %[[EmptyList:.*]] = torch.prim.ListConstruct : () -> !torch.list<int>
+  // CHECK: %[[Rank0:.*]] = torch.aten.full %[[EmptyList]], %[[padVal]], %[[int7]], %[[none]], %[[none]], %[[none]] : !torch.list<int>, !torch.float, !torch.int, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[],f64>
+  // CHECK: %[[Clamp:.*]] = torch.aten.clamp %[[Rank0]], %[[qMin]], %[[qMax]] : !torch.vtensor<[],f64>, !torch.float, !torch.float -> !torch.vtensor<[],f64>
+  // CHECK: %[[Item:.*]] = torch.aten.item %[[Clamp]] : !torch.vtensor<[],f64> -> !torch.float
+  // CHECK: %[[NewPad:.*]] = torch.aten.pad %arg0, %[[PadList]], %[[str]], %[[Item]] : !torch.vtensor<[8,8],si8>, !torch.list<int>, !torch.str, !torch.float -> !torch.vtensor<[9,11],si8>
+  // CHECK: %[[NewMPTQT:.*]] = torch.aten._make_per_tensor_quantized_tensor %[[NewPad]], %[[cstHalf]], %[[int1]] : !torch.vtensor<[9,11],si8>, !torch.float, !torch.int -> !torch.vtensor<[9,11],!torch.qint8>
+  // CHECK: %[[OtherMPTQT:.*]] = torch.aten._make_per_tensor_quantized_tensor %arg1, %[[cstHalf]], %[[int0]] : !torch.vtensor<[11,4],si8>, !torch.float, !torch.int -> !torch.vtensor<[11,4],!torch.qint8>
+  // CHECK: %[[MM:.*]] = torch.aten.mm %[[NewMPTQT]], %[[OtherMPTQT]] : !torch.vtensor<[9,11],!torch.qint8>, !torch.vtensor<[11,4],!torch.qint8> -> !torch.vtensor<[9,4],!torch.qint32>
+  %scale = torch.constant.float 0.5
+  %false = torch.constant.bool false
+  %zero = torch.constant.int 0
+  %one = torch.constant.int 1
+  %two = torch.constant.int 2
+  %floatpad = torch.constant.float 3.5
+  %zp = torch.constant.int -128
+  %6 = torch.aten._make_per_tensor_quantized_tensor %arg0, %scale, %one : !torch.vtensor<[8,8],si8>, !torch.float, !torch.int -> !torch.vtensor<[8,8],!torch.qint8>
+  %7 = torch.aten.dequantize.tensor %6 : !torch.vtensor<[8,8],!torch.qint8> -> !torch.vtensor<[8,8],f32>
+  %list = torch.prim.ListConstruct %one, %two, %zero, %one : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
+  %str = torch.constant.str "constant"
+  %pad = torch.aten.pad %7, %list, %str, %floatpad : !torch.vtensor<[8,8],f32>, !torch.list<int>, !torch.str, !torch.float -> !torch.vtensor<[9,11],f32>
+  %12 = torch.aten._make_per_tensor_quantized_tensor %arg1, %scale, %zero : !torch.vtensor<[11,4],si8>, !torch.float, !torch.int -> !torch.vtensor<[11,4],!torch.qint8>
+  %13 = torch.aten.dequantize.tensor %12 : !torch.vtensor<[11,4],!torch.qint8> -> !torch.vtensor<[11,4],f32>
+  %16 = torch.aten.mm %pad, %13 : !torch.vtensor<[9,11],f32>, !torch.vtensor<[11,4],f32> -> !torch.vtensor<[9,4],f32>
+  return %16 : !torch.vtensor<[9,4],f32>
+}
+
+// -----
+
 // CHECK-LABEL: @convolution_bias
 func.func @convolution_bias(%arg0: !torch.vtensor<[1,3,8,8],si8>, %arg1: !torch.vtensor<[3,3,2,2],si8>, %arg2 : !torch.vtensor<[3], f32>) -> !torch.vtensor<[1,3,7,7],f32> {
   %scale = torch.constant.float 0.5