Includes some minor first for `AffineMap::inferFromExprList`
pull/2889/head
Rob Suderman 2024-02-09 14:07:49 -08:00 committed by GitHub
parent 7d33ba69ac
commit d83b576c6e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 18 additions and 14 deletions

@ -1 +1 @@
Subproject commit 70eb0e37a86747f9266e4c8380baa89746f5e23b Subproject commit bb180856ec28efe305dc77ca4bb3db12d8932edf

View File

@ -498,7 +498,8 @@ public:
resultExpr.push_back(rewriter.getAffineDimExpr(i)); resultExpr.push_back(rewriter.getAffineDimExpr(i));
} }
auto indexingMaps = AffineMap::inferFromExprList({indicesExpr, resultExpr}); auto indexingMaps = AffineMap::inferFromExprList({indicesExpr, resultExpr},
rewriter.getContext());
Value finalRes = Value finalRes =
rewriter rewriter

View File

@ -512,8 +512,8 @@ public:
resultShape.insert(resultShape.end(), {lhsDim0, rhsDim1}); resultShape.insert(resultShape.end(), {lhsDim0, rhsDim1});
Value zeroTensor = Value zeroTensor =
createZeroInitTensor(rewriter, loc, resultShape, elementType); createZeroInitTensor(rewriter, loc, resultShape, elementType);
auto indexingMaps = auto indexingMaps = AffineMap::inferFromExprList(
AffineMap::inferFromExprList({lhsExpr, rhsExpr, outExpr}); {lhsExpr, rhsExpr, outExpr}, rewriter.getContext());
iteratorTypes.insert(iteratorTypes.end(), iteratorTypes.insert(iteratorTypes.end(),
{utils::IteratorType::parallel, {utils::IteratorType::parallel,
utils::IteratorType::reduction, utils::IteratorType::reduction,

View File

@ -442,8 +442,8 @@ public:
// Here we have six dimensions, each corresponding to N, C, Hout, Wout, kH, // Here we have six dimensions, each corresponding to N, C, Hout, Wout, kH,
// and kW, respectively, as described in the algorithm above. // and kW, respectively, as described in the algorithm above.
SmallVector<AffineMap> indexingMaps = SmallVector<AffineMap> indexingMaps = AffineMap::inferFromExprList(
AffineMap::inferFromExprList({inputExprs, kernelExprs, outputExprs}); {inputExprs, kernelExprs, outputExprs}, rewriter.getContext());
SmallVector<utils::IteratorType> iteratorTypes( SmallVector<utils::IteratorType> iteratorTypes(
4, utils::IteratorType::parallel); 4, utils::IteratorType::parallel);
iteratorTypes.push_back(utils::IteratorType::reduction); iteratorTypes.push_back(utils::IteratorType::reduction);
@ -724,7 +724,7 @@ public:
kSizeTensorExprs.push_back(rewriter.getAffineDimExpr(2)); kSizeTensorExprs.push_back(rewriter.getAffineDimExpr(2));
kIterExprs.push_back(rewriter.getAffineDimExpr(3)); kIterExprs.push_back(rewriter.getAffineDimExpr(3));
SmallVector<AffineMap> indexingMaps = AffineMap::inferFromExprList( SmallVector<AffineMap> indexingMaps = AffineMap::inferFromExprList(
{kIterExprs, outputExprs, kSizeTensorExprs}); {kIterExprs, outputExprs, kSizeTensorExprs}, rewriter.getContext());
SmallVector<utils::IteratorType> iteratorTypes( SmallVector<utils::IteratorType> iteratorTypes(
3, utils::IteratorType::parallel); 3, utils::IteratorType::parallel);
iteratorTypes.push_back(utils::IteratorType::reduction); iteratorTypes.push_back(utils::IteratorType::reduction);
@ -774,8 +774,8 @@ public:
// make a linalg generic to divide each element by the corresponding // make a linalg generic to divide each element by the corresponding
// Kernel Width. This step is only necessary for avg pooling. // Kernel Width. This step is only necessary for avg pooling.
SmallVector<AffineMap> indexingMaps1 = SmallVector<AffineMap> indexingMaps1 = AffineMap::inferFromExprList(
AffineMap::inferFromExprList({kSizeTensorExprs, outputExprs}); {kSizeTensorExprs, outputExprs}, rewriter.getContext());
SmallVector<utils::IteratorType> iteratorTypes1( SmallVector<utils::IteratorType> iteratorTypes1(
3, utils::IteratorType::parallel); 3, utils::IteratorType::parallel);
auto output = rewriter.create<linalg::GenericOp>( auto output = rewriter.create<linalg::GenericOp>(
@ -916,8 +916,8 @@ public:
for (unsigned i = rank; i < 2 * rank - 2; i++) { for (unsigned i = rank; i < 2 * rank - 2; i++) {
kIterExprs.push_back(rewriter.getAffineDimExpr(i)); kIterExprs.push_back(rewriter.getAffineDimExpr(i));
} }
SmallVector<AffineMap> indexingMaps = SmallVector<AffineMap> indexingMaps = AffineMap::inferFromExprList(
AffineMap::inferFromExprList({kIterExprs, outputExprs, auxTensorExprs}); {kIterExprs, outputExprs, auxTensorExprs}, rewriter.getContext());
SmallVector<utils::IteratorType> iteratorTypes( SmallVector<utils::IteratorType> iteratorTypes(
rank, utils::IteratorType::parallel); rank, utils::IteratorType::parallel);
for (unsigned i = 0; i < rank - 2; i++) { for (unsigned i = 0; i < rank - 2; i++) {

View File

@ -167,7 +167,8 @@ public:
resultExprs.push_back(rewriter.getAffineDimExpr(size.index())); resultExprs.push_back(rewriter.getAffineDimExpr(size.index()));
} }
} }
auto maps = AffineMap::inferFromExprList({exprs, resultExprs, resultExprs}); auto maps = AffineMap::inferFromExprList({exprs, resultExprs, resultExprs},
rewriter.getContext());
auto linalgOp = rewriter.create<linalg::GenericOp>( auto linalgOp = rewriter.create<linalg::GenericOp>(
loc, loc,
ArrayRef<Type>({filledTensorVal.getType(), filledTensorIdx.getType()}), ArrayRef<Type>({filledTensorVal.getType(), filledTensorIdx.getType()}),

View File

@ -197,7 +197,8 @@ Value torch_to_linalg::createReductionLinalgGeneric(
} }
} }
auto indexingMaps = AffineMap::inferFromExprList({exprs, resultExprs}); auto indexingMaps =
AffineMap::inferFromExprList({exprs, resultExprs}, b.getContext());
Value accumulator = Value accumulator =
createInitTensor(b, loc, resultShape, initElem.getType(), initElem); createInitTensor(b, loc, resultShape, initElem.getType(), initElem);

View File

@ -1064,7 +1064,8 @@ public:
rewriter.getAffineDimExpr(tensorOperandRank)); rewriter.getAffineDimExpr(tensorOperandRank));
SmallVector<AffineMap> indexingMaps = AffineMap::inferFromExprList( SmallVector<AffineMap> indexingMaps = AffineMap::inferFromExprList(
{originalIndicesDimExprs, updatedIndicesDimExprs}); {originalIndicesDimExprs, updatedIndicesDimExprs},
rewriter.getContext());
SmallVector<utils::IteratorType> iteratorTypes( SmallVector<utils::IteratorType> iteratorTypes(
tensorOperandRank + 1, utils::IteratorType::parallel); tensorOperandRank + 1, utils::IteratorType::parallel);