diff --git a/lib/Conversion/TorchToLinalg/Linear.cpp b/lib/Conversion/TorchToLinalg/Linear.cpp index 92ed647ef..8117551bf 100644 --- a/lib/Conversion/TorchToLinalg/Linear.cpp +++ b/lib/Conversion/TorchToLinalg/Linear.cpp @@ -103,7 +103,8 @@ public: Location loc = op->getLoc(); MLIRContext *context = op.getContext(); Value self = adaptor.getSelf(); - auto selfRank = adaptor.getSelf().getType().cast().getRank(); + auto selfRank = + adaptor.getSelf().getType().cast().getRank(); Type elementType = adaptor.getSelf().getType().cast().getElementType(); Value c1 = @@ -506,18 +507,18 @@ public: return rewriter.create(loc, intType, v); }; - SmallVector paddingIntValues; - if (!getListConstructElements(op.getPadding(), paddingIntValues)) + SmallVector paddingInts; + if (!matchPattern(op.getPadding(), m_TorchListOfConstantInts(paddingInts))) { return rewriter.notifyMatchFailure( - op, "only support padding from a list construct"); - paddingIntValues = getTypeConvertedValues(rewriter, loc, getTypeConverter(), - paddingIntValues); + op, "only support constant padding values"); + } SmallVector strideInts; if (!matchPattern(op.getStride(), m_TorchListOfConstantInts(strideInts))) return rewriter.notifyMatchFailure(op, "only support constant int strides"); SmallVector dilationInts; - if (!matchPattern(op.getDilation(), m_TorchListOfConstantInts(dilationInts))) + if (!matchPattern(op.getDilation(), + m_TorchListOfConstantInts(dilationInts))) return rewriter.notifyMatchFailure(op, "only support constant int dilations"); @@ -554,6 +555,8 @@ public: "invalid: groups must divide weight batch size evenly."); SmallVector dilationIntValues = getAsConstantIntValues(rewriter, loc, dilationInts); + SmallVector paddingIntValues = + getAsConstantIntValues(rewriter, loc, paddingInts); SmallVector strideIntValues = getAsConstantIntValues(rewriter, loc, strideInts); @@ -651,8 +654,11 @@ public: } else { // Pad input - paddedInput = torch_to_linalg::getDynamicZeroPaddedTensor( - op, rewriter, input, paddingIntValues, /*unpaddedDims=*/2); + SmallVector paddingIncludingNC = {0, 0}; + paddingIncludingNC.insert(paddingIncludingNC.end(), paddingInts.begin(), + paddingInts.end()); + paddedInput = torch_to_linalg::getZeroPaddedTensor(op, rewriter, input, + paddingIncludingNC); // Calculate output dims for (size_t i = 0; i < numSpacialDims; i++)