Remove dynamic padding.

cuda_f16
Prashant Kumar 2022-12-13 09:46:10 +00:00
parent 1eccac1264
commit 44b6cb0973
1 changed file with 15 additions and 9 deletions

View File

@ -103,7 +103,8 @@ public:
Location loc = op->getLoc(); Location loc = op->getLoc();
MLIRContext *context = op.getContext(); MLIRContext *context = op.getContext();
Value self = adaptor.getSelf(); Value self = adaptor.getSelf();
auto selfRank = adaptor.getSelf().getType().cast<RankedTensorType>().getRank(); auto selfRank =
adaptor.getSelf().getType().cast<RankedTensorType>().getRank();
Type elementType = Type elementType =
adaptor.getSelf().getType().cast<RankedTensorType>().getElementType(); adaptor.getSelf().getType().cast<RankedTensorType>().getElementType();
Value c1 = Value c1 =
@ -506,18 +507,18 @@ public:
return rewriter.create<arith::IndexCastOp>(loc, intType, v); return rewriter.create<arith::IndexCastOp>(loc, intType, v);
}; };
SmallVector<Value> paddingIntValues; SmallVector<int64_t> paddingInts;
if (!getListConstructElements(op.getPadding(), paddingIntValues)) if (!matchPattern(op.getPadding(), m_TorchListOfConstantInts(paddingInts))) {
return rewriter.notifyMatchFailure( return rewriter.notifyMatchFailure(
op, "only support padding from a list construct"); op, "only support constant padding values");
paddingIntValues = getTypeConvertedValues(rewriter, loc, getTypeConverter(), }
paddingIntValues);
SmallVector<int64_t> strideInts; SmallVector<int64_t> strideInts;
if (!matchPattern(op.getStride(), m_TorchListOfConstantInts(strideInts))) if (!matchPattern(op.getStride(), m_TorchListOfConstantInts(strideInts)))
return rewriter.notifyMatchFailure(op, return rewriter.notifyMatchFailure(op,
"only support constant int strides"); "only support constant int strides");
SmallVector<int64_t> dilationInts; SmallVector<int64_t> dilationInts;
if (!matchPattern(op.getDilation(), m_TorchListOfConstantInts(dilationInts))) if (!matchPattern(op.getDilation(),
m_TorchListOfConstantInts(dilationInts)))
return rewriter.notifyMatchFailure(op, return rewriter.notifyMatchFailure(op,
"only support constant int dilations"); "only support constant int dilations");
@ -554,6 +555,8 @@ public:
"invalid: groups must divide weight batch size evenly."); "invalid: groups must divide weight batch size evenly.");
SmallVector<Value> dilationIntValues = SmallVector<Value> dilationIntValues =
getAsConstantIntValues(rewriter, loc, dilationInts); getAsConstantIntValues(rewriter, loc, dilationInts);
SmallVector<Value> paddingIntValues =
getAsConstantIntValues(rewriter, loc, paddingInts);
SmallVector<Value> strideIntValues = SmallVector<Value> strideIntValues =
getAsConstantIntValues(rewriter, loc, strideInts); getAsConstantIntValues(rewriter, loc, strideInts);
@ -651,8 +654,11 @@ public:
} else { } else {
// Pad input // Pad input
paddedInput = torch_to_linalg::getDynamicZeroPaddedTensor( SmallVector<int64_t, 4> paddingIncludingNC = {0, 0};
op, rewriter, input, paddingIntValues, /*unpaddedDims=*/2); paddingIncludingNC.insert(paddingIncludingNC.end(), paddingInts.begin(),
paddingInts.end());
paddedInput = torch_to_linalg::getZeroPaddedTensor(op, rewriter, input,
paddingIncludingNC);
// Calculate output dims // Calculate output dims
for (size_t i = 0; i < numSpacialDims; i++) for (size_t i = 0; i < numSpacialDims; i++)