diff --git a/lib/Conversion/TorchToLinalg/Linear.cpp b/lib/Conversion/TorchToLinalg/Linear.cpp index 76bf0c13d..52765411b 100644 --- a/lib/Conversion/TorchToLinalg/Linear.cpp +++ b/lib/Conversion/TorchToLinalg/Linear.cpp @@ -1080,21 +1080,16 @@ public: return rewriter.notifyMatchFailure(op, "expect bias to be rank 1"); auto resultRank = cast(initTensor.getType()).getRank(); - SmallVector indexingMaps = { - // bias is used to initialize the channels - dimension 1 of output - AffineMap::get(/*dimCount=*/resultRank, /*symbolCount=*/0, - rewriter.getAffineDimExpr(1), context), - rewriter.getMultiDimIdentityMap(resultRank)}; - SmallVector iteratorTypes( - resultRank, utils::IteratorType::parallel); + SmallVector addedDimensions; + // bias is used to initialize the channels - dimension 1 of + // output + for (int i = 0; i < resultRank; ++i) + if (i != 1) + addedDimensions.push_back(i); outputTensor = rewriter - .create( - loc, initTensor.getType(), bias, initTensor, - indexingMaps, iteratorTypes, - [](OpBuilder &b, Location loc, ValueRange args) { - b.create(loc, args[0]); - }) - .getResult(0); + .create(loc, bias, initTensor, + addedDimensions) + ->getResult(0); } auto stridesAttr = rewriter.getI64VectorAttr(strideInts);