Remove dynamic padding.

cuda_f16
Prashant Kumar 2022-12-13 09:46:10 +00:00
parent 1eccac1264
commit 44b6cb0973
1 changed files with 15 additions and 9 deletions

View File

@ -103,7 +103,8 @@ public:
Location loc = op->getLoc();
MLIRContext *context = op.getContext();
Value self = adaptor.getSelf();
auto selfRank = adaptor.getSelf().getType().cast<RankedTensorType>().getRank();
auto selfRank =
adaptor.getSelf().getType().cast<RankedTensorType>().getRank();
Type elementType =
adaptor.getSelf().getType().cast<RankedTensorType>().getElementType();
Value c1 =
@ -506,18 +507,18 @@ public:
return rewriter.create<arith::IndexCastOp>(loc, intType, v);
};
SmallVector<Value> paddingIntValues;
if (!getListConstructElements(op.getPadding(), paddingIntValues))
SmallVector<int64_t> paddingInts;
if (!matchPattern(op.getPadding(), m_TorchListOfConstantInts(paddingInts))) {
return rewriter.notifyMatchFailure(
op, "only support padding from a list construct");
paddingIntValues = getTypeConvertedValues(rewriter, loc, getTypeConverter(),
paddingIntValues);
op, "only support constant padding values");
}
SmallVector<int64_t> strideInts;
if (!matchPattern(op.getStride(), m_TorchListOfConstantInts(strideInts)))
return rewriter.notifyMatchFailure(op,
"only support constant int strides");
SmallVector<int64_t> dilationInts;
if (!matchPattern(op.getDilation(), m_TorchListOfConstantInts(dilationInts)))
if (!matchPattern(op.getDilation(),
m_TorchListOfConstantInts(dilationInts)))
return rewriter.notifyMatchFailure(op,
"only support constant int dilations");
@ -554,6 +555,8 @@ public:
"invalid: groups must divide weight batch size evenly.");
SmallVector<Value> dilationIntValues =
getAsConstantIntValues(rewriter, loc, dilationInts);
SmallVector<Value> paddingIntValues =
getAsConstantIntValues(rewriter, loc, paddingInts);
SmallVector<Value> strideIntValues =
getAsConstantIntValues(rewriter, loc, strideInts);
@ -651,8 +654,11 @@ public:
} else {
// Pad input
paddedInput = torch_to_linalg::getDynamicZeroPaddedTensor(
op, rewriter, input, paddingIntValues, /*unpaddedDims=*/2);
SmallVector<int64_t, 4> paddingIncludingNC = {0, 0};
paddingIncludingNC.insert(paddingIncludingNC.end(), paddingInts.begin(),
paddingInts.end());
paddedInput = torch_to_linalg::getZeroPaddedTensor(op, rewriter, input,
paddingIncludingNC);
// Calculate output dims
for (size_t i = 0; i < numSpacialDims; i++)