Remove unused input tensor from linalg.generic in aten.convolution (#1487)

This commit removes the `weight` tensor from the inputs of one of the
`linalg.generic` ops generated by the `aten.convolution` linalg
lowering, since the indexed values are not actually used by the body
of the `linalg.generic`. Moreover, in general the `weight` tensor does
not have the same shape as the output tensor of the `linalg.generic`,
so both tensors being indexed by the same indexing maps is wrong.
pull/1488/head
Ramiro Leal-Cavazos 2022-10-12 14:01:24 -07:00 committed by GitHub
parent b487113ef1
commit 8f76c74be9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 3 additions and 3 deletions

View File

@ -572,11 +572,11 @@ public:
createZeroInitTensor(rewriter, loc, weightInitDims, elementType); createZeroInitTensor(rewriter, loc, weightInitDims, elementType);
SmallVector<StringRef> iteratorTypes(inRank, SmallVector<StringRef> iteratorTypes(inRank,
getParallelIteratorTypeName()); getParallelIteratorTypeName());
SmallVector<AffineMap> indexingMaps( SmallVector<AffineMap> indexingMaps{
2, AffineMap::getMultiDimIdentityMap(inRank, context)); AffineMap::getMultiDimIdentityMap(inRank, context)};
weight = rewriter weight = rewriter
.create<linalg::GenericOp>( .create<linalg::GenericOp>(
loc, weightInitTensor.getType(), weight, loc, weightInitTensor.getType(), ValueRange{},
weightInitTensor, indexingMaps, iteratorTypes, weightInitTensor, indexingMaps, iteratorTypes,
[&](OpBuilder &b, Location loc, ValueRange args) { [&](OpBuilder &b, Location loc, ValueRange args) {
SmallVector<Value> indices; SmallVector<Value> indices;