[TorchToLinalg] Use `linalg.broadcast` instead of `generic` for conv bias (#3661)

The current implementation uses a `linalg.generic` to broadcast the bias
tensor when lowering convolutions. A generic op is harder for later
patterns to match than a named op, so this patch changes the lowering to
use the corresponding named op, `linalg.broadcast`, instead.
pull/3637/head
Felix Schneider 2024-08-26 20:29:11 +02:00 committed by GitHub
parent fa39d91357
commit 638ef14512
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed file with 9 additions and 14 deletions

View File

@@ -1080,21 +1080,16 @@ public:
return rewriter.notifyMatchFailure(op, "expect bias to be rank 1");
auto resultRank = cast<RankedTensorType>(initTensor.getType()).getRank();
SmallVector<AffineMap> indexingMaps = {
// bias is used to initialize the channels - dimension 1 of output
AffineMap::get(/*dimCount=*/resultRank, /*symbolCount=*/0,
rewriter.getAffineDimExpr(1), context),
rewriter.getMultiDimIdentityMap(resultRank)};
SmallVector<utils::IteratorType> iteratorTypes(
resultRank, utils::IteratorType::parallel);
SmallVector<int64_t, 4> addedDimensions;
// bias is used to initialize the channels - dimension 1 of
// output
for (int i = 0; i < resultRank; ++i)
if (i != 1)
addedDimensions.push_back(i);
outputTensor = rewriter
.create<linalg::GenericOp>(
loc, initTensor.getType(), bias, initTensor,
indexingMaps, iteratorTypes,
[](OpBuilder &b, Location loc, ValueRange args) {
b.create<linalg::YieldOp>(loc, args[0]);
})
.getResult(0);
.create<linalg::BroadcastOp>(loc, bias, initTensor,
addedDimensions)
->getResult(0);
}
auto stridesAttr = rewriter.getI64VectorAttr(strideInts);