mirror of https://github.com/llvm/torch-mlir
[TorchToLinalg] Use `linalg.broadcast` instead of `generic` for conv bias (#3661)
The current implementation uses a `linalg.generic` to broadcast the bias tensor for the lowering of convolutions. This is suboptimal for later pattern matching. This patch changes it to use the respective named op, `linalg.broadcast`, instead.pull/3637/head
parent
fa39d91357
commit
638ef14512
|
@ -1080,21 +1080,16 @@ public:
|
|||
return rewriter.notifyMatchFailure(op, "expect bias to be rank 1");
|
||||
|
||||
auto resultRank = cast<RankedTensorType>(initTensor.getType()).getRank();
|
||||
SmallVector<AffineMap> indexingMaps = {
|
||||
// bias is used to initialize the channels - dimension 1 of output
|
||||
AffineMap::get(/*dimCount=*/resultRank, /*symbolCount=*/0,
|
||||
rewriter.getAffineDimExpr(1), context),
|
||||
rewriter.getMultiDimIdentityMap(resultRank)};
|
||||
SmallVector<utils::IteratorType> iteratorTypes(
|
||||
resultRank, utils::IteratorType::parallel);
|
||||
SmallVector<int64_t, 4> addedDimensions;
|
||||
// bias is used to initialize the channels - dimension 1 of
|
||||
// output
|
||||
for (int i = 0; i < resultRank; ++i)
|
||||
if (i != 1)
|
||||
addedDimensions.push_back(i);
|
||||
outputTensor = rewriter
|
||||
.create<linalg::GenericOp>(
|
||||
loc, initTensor.getType(), bias, initTensor,
|
||||
indexingMaps, iteratorTypes,
|
||||
[](OpBuilder &b, Location loc, ValueRange args) {
|
||||
b.create<linalg::YieldOp>(loc, args[0]);
|
||||
})
|
||||
.getResult(0);
|
||||
.create<linalg::BroadcastOp>(loc, bias, initTensor,
|
||||
addedDimensions)
|
||||
->getResult(0);
|
||||
}
|
||||
|
||||
auto stridesAttr = rewriter.getI64VectorAttr(strideInts);
|
||||
|
|
Loading…
Reference in New Issue