mirror of https://github.com/llvm/torch-mlir
Remove dynamic padding.
parent 1eccac1264
commit 44b6cb0973
@@ -103,7 +103,8 @@ public:
     Location loc = op->getLoc();
     MLIRContext *context = op.getContext();
     Value self = adaptor.getSelf();
-    auto selfRank = adaptor.getSelf().getType().cast<RankedTensorType>().getRank();
+    auto selfRank =
+        adaptor.getSelf().getType().cast<RankedTensorType>().getRank();
     Type elementType =
         adaptor.getSelf().getType().cast<RankedTensorType>().getElementType();
     Value c1 =
@@ -506,18 +507,18 @@ public:
       return rewriter.create<arith::IndexCastOp>(loc, intType, v);
     };
 
-    SmallVector<Value> paddingIntValues;
-    if (!getListConstructElements(op.getPadding(), paddingIntValues))
+    SmallVector<int64_t> paddingInts;
+    if (!matchPattern(op.getPadding(), m_TorchListOfConstantInts(paddingInts))) {
       return rewriter.notifyMatchFailure(
-          op, "only support padding from a list construct");
-    paddingIntValues = getTypeConvertedValues(rewriter, loc, getTypeConverter(),
-                                              paddingIntValues);
+          op, "only support constant padding values");
+    }
     SmallVector<int64_t> strideInts;
     if (!matchPattern(op.getStride(), m_TorchListOfConstantInts(strideInts)))
       return rewriter.notifyMatchFailure(op,
                                          "only support constant int strides");
     SmallVector<int64_t> dilationInts;
-    if (!matchPattern(op.getDilation(), m_TorchListOfConstantInts(dilationInts)))
+    if (!matchPattern(op.getDilation(),
+                      m_TorchListOfConstantInts(dilationInts)))
       return rewriter.notifyMatchFailure(op,
                                          "only support constant int dilations");
 
@@ -554,6 +555,8 @@ public:
           "invalid: groups must divide weight batch size evenly.");
     SmallVector<Value> dilationIntValues =
         getAsConstantIntValues(rewriter, loc, dilationInts);
+    SmallVector<Value> paddingIntValues =
+        getAsConstantIntValues(rewriter, loc, paddingInts);
     SmallVector<Value> strideIntValues =
         getAsConstantIntValues(rewriter, loc, strideInts);
 
@@ -651,8 +654,11 @@ public:
 
     } else {
       // Pad input
-      paddedInput = torch_to_linalg::getDynamicZeroPaddedTensor(
-          op, rewriter, input, paddingIntValues, /*unpaddedDims=*/2);
+      SmallVector<int64_t, 4> paddingIncludingNC = {0, 0};
+      paddingIncludingNC.insert(paddingIncludingNC.end(), paddingInts.begin(),
+                                paddingInts.end());
+      paddedInput = torch_to_linalg::getZeroPaddedTensor(op, rewriter, input,
+                                                         paddingIncludingNC);
 
       // Calculate output dims
       for (size_t i = 0; i < numSpacialDims; i++)
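
For context, the new path builds one zero-padding amount per input dimension by prefixing the constant spatial padding with zeros for the batch and channel dimensions, so the whole vector is known at compile time and the dynamic padding helper is no longer needed. Below is a minimal standalone sketch of that step, with std::vector standing in for llvm::SmallVector and padding values chosen purely for illustration.

// Standalone sketch (plain C++17, no MLIR dependencies) of how
// `paddingIncludingNC` is assembled in the diff above. The example
// padding values {1, 2} are hypothetical.
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  // Constant spatial padding as it would be matched from the op's padding
  // list, e.g. a 2-D convolution with padding = {1, 2} (height, width).
  std::vector<int64_t> paddingInts = {1, 2};

  // Prepend zeros so the batch (N) and channel (C) dimensions are not padded.
  std::vector<int64_t> paddingIncludingNC = {0, 0};
  paddingIncludingNC.insert(paddingIncludingNC.end(), paddingInts.begin(),
                            paddingInts.end());

  // Prints "0 0 1 2": one padding amount per NCHW input dimension.
  for (int64_t p : paddingIncludingNC)
    std::cout << p << ' ';
  std::cout << '\n';
}

Because every entry is now a compile-time constant, the lowering can call torch_to_linalg::getZeroPaddedTensor instead of the dynamic variant, which is the point of the commit.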