[TorchToLinalg] Use `linalg.broadcast` instead of `generic` for conv bias (#3661)

The current implementation uses a `linalg.generic` to broadcast the bias
tensor for the lowering of convolutions. This is suboptimal for later
pattern matching. This patch changes it to use the respective named op,
`linalg.broadcast`, instead.
pull/3637/head
Felix Schneider 2024-08-26 20:29:11 +02:00 committed by GitHub
parent fa39d91357
commit 638ef14512
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 9 additions and 14 deletions

View File

@ -1080,21 +1080,16 @@ public:
return rewriter.notifyMatchFailure(op, "expect bias to be rank 1"); return rewriter.notifyMatchFailure(op, "expect bias to be rank 1");
auto resultRank = cast<RankedTensorType>(initTensor.getType()).getRank(); auto resultRank = cast<RankedTensorType>(initTensor.getType()).getRank();
SmallVector<AffineMap> indexingMaps = { SmallVector<int64_t, 4> addedDimensions;
// bias is used to initialize the channels - dimension 1 of output // bias is used to initialize the channels - dimension 1 of
AffineMap::get(/*dimCount=*/resultRank, /*symbolCount=*/0, // output
rewriter.getAffineDimExpr(1), context), for (int i = 0; i < resultRank; ++i)
rewriter.getMultiDimIdentityMap(resultRank)}; if (i != 1)
SmallVector<utils::IteratorType> iteratorTypes( addedDimensions.push_back(i);
resultRank, utils::IteratorType::parallel);
outputTensor = rewriter outputTensor = rewriter
.create<linalg::GenericOp>( .create<linalg::BroadcastOp>(loc, bias, initTensor,
loc, initTensor.getType(), bias, initTensor, addedDimensions)
indexingMaps, iteratorTypes, ->getResult(0);
[](OpBuilder &b, Location loc, ValueRange args) {
b.create<linalg::YieldOp>(loc, args[0]);
})
.getResult(0);
} }
auto stridesAttr = rewriter.getI64VectorAttr(strideInts); auto stridesAttr = rewriter.getI64VectorAttr(strideInts);