mirror of https://github.com/llvm/torch-mlir
Remove dynamic padding.
parent
1eccac1264
commit
44b6cb0973
|
@ -103,7 +103,8 @@ public:
|
||||||
Location loc = op->getLoc();
|
Location loc = op->getLoc();
|
||||||
MLIRContext *context = op.getContext();
|
MLIRContext *context = op.getContext();
|
||||||
Value self = adaptor.getSelf();
|
Value self = adaptor.getSelf();
|
||||||
auto selfRank = adaptor.getSelf().getType().cast<RankedTensorType>().getRank();
|
auto selfRank =
|
||||||
|
adaptor.getSelf().getType().cast<RankedTensorType>().getRank();
|
||||||
Type elementType =
|
Type elementType =
|
||||||
adaptor.getSelf().getType().cast<RankedTensorType>().getElementType();
|
adaptor.getSelf().getType().cast<RankedTensorType>().getElementType();
|
||||||
Value c1 =
|
Value c1 =
|
||||||
|
@ -506,18 +507,18 @@ public:
|
||||||
return rewriter.create<arith::IndexCastOp>(loc, intType, v);
|
return rewriter.create<arith::IndexCastOp>(loc, intType, v);
|
||||||
};
|
};
|
||||||
|
|
||||||
SmallVector<Value> paddingIntValues;
|
SmallVector<int64_t> paddingInts;
|
||||||
if (!getListConstructElements(op.getPadding(), paddingIntValues))
|
if (!matchPattern(op.getPadding(), m_TorchListOfConstantInts(paddingInts))) {
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
op, "only support padding from a list construct");
|
op, "only support constant padding values");
|
||||||
paddingIntValues = getTypeConvertedValues(rewriter, loc, getTypeConverter(),
|
}
|
||||||
paddingIntValues);
|
|
||||||
SmallVector<int64_t> strideInts;
|
SmallVector<int64_t> strideInts;
|
||||||
if (!matchPattern(op.getStride(), m_TorchListOfConstantInts(strideInts)))
|
if (!matchPattern(op.getStride(), m_TorchListOfConstantInts(strideInts)))
|
||||||
return rewriter.notifyMatchFailure(op,
|
return rewriter.notifyMatchFailure(op,
|
||||||
"only support constant int strides");
|
"only support constant int strides");
|
||||||
SmallVector<int64_t> dilationInts;
|
SmallVector<int64_t> dilationInts;
|
||||||
if (!matchPattern(op.getDilation(), m_TorchListOfConstantInts(dilationInts)))
|
if (!matchPattern(op.getDilation(),
|
||||||
|
m_TorchListOfConstantInts(dilationInts)))
|
||||||
return rewriter.notifyMatchFailure(op,
|
return rewriter.notifyMatchFailure(op,
|
||||||
"only support constant int dilations");
|
"only support constant int dilations");
|
||||||
|
|
||||||
|
@ -554,6 +555,8 @@ public:
|
||||||
"invalid: groups must divide weight batch size evenly.");
|
"invalid: groups must divide weight batch size evenly.");
|
||||||
SmallVector<Value> dilationIntValues =
|
SmallVector<Value> dilationIntValues =
|
||||||
getAsConstantIntValues(rewriter, loc, dilationInts);
|
getAsConstantIntValues(rewriter, loc, dilationInts);
|
||||||
|
SmallVector<Value> paddingIntValues =
|
||||||
|
getAsConstantIntValues(rewriter, loc, paddingInts);
|
||||||
SmallVector<Value> strideIntValues =
|
SmallVector<Value> strideIntValues =
|
||||||
getAsConstantIntValues(rewriter, loc, strideInts);
|
getAsConstantIntValues(rewriter, loc, strideInts);
|
||||||
|
|
||||||
|
@ -651,8 +654,11 @@ public:
|
||||||
|
|
||||||
} else {
|
} else {
|
||||||
// Pad input
|
// Pad input
|
||||||
paddedInput = torch_to_linalg::getDynamicZeroPaddedTensor(
|
SmallVector<int64_t, 4> paddingIncludingNC = {0, 0};
|
||||||
op, rewriter, input, paddingIntValues, /*unpaddedDims=*/2);
|
paddingIncludingNC.insert(paddingIncludingNC.end(), paddingInts.begin(),
|
||||||
|
paddingInts.end());
|
||||||
|
paddedInput = torch_to_linalg::getZeroPaddedTensor(op, rewriter, input,
|
||||||
|
paddingIncludingNC);
|
||||||
|
|
||||||
// Calculate output dims
|
// Calculate output dims
|
||||||
for (size_t i = 0; i < numSpacialDims; i++)
|
for (size_t i = 0; i < numSpacialDims; i++)
|
||||||
|
|
Loading…
Reference in New Issue