[TorchToLinalg] remove `extract_slice` grid_sample lowering (#3483)

Instead of extracting the X and Y grid channels with `tensor.extract_slice` ops, use affine constant expressions in the indexing maps so the X and Y values are accessed directly inside the `linalg.generic` op's region.
pull/3657/head
Ian Wood 2024-08-20 14:23:43 -07:00 committed by GitHub
parent f66908f190
commit a24114efa3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 25 additions and 51 deletions

View File

@ -2431,9 +2431,7 @@ public:
Location loc = op->getLoc();
Type int64type = rewriter.getI64Type();
Type floatType = rewriter.getF32Type();
Value zeroIndex = rewriter.create<arith::ConstantIndexOp>(loc, 0);
Value oneIndex = rewriter.create<arith::ConstantIndexOp>(loc, 1);
Value twoIndex = rewriter.create<arith::ConstantIndexOp>(loc, 2);
Value zeroFloat = rewriter.create<arith::ConstantOp>(
loc, rewriter.getFloatAttr(floatType, 0.0));
Value oneFloat = rewriter.create<arith::ConstantOp>(
@ -2442,7 +2440,6 @@ public:
loc, rewriter.getFloatAttr(floatType, 2.0));
Value input = adaptor.getInput();
auto inputType = cast<RankedTensorType>(input.getType());
auto inputShape = inputType.getShape();
Value innerDim0a = rewriter.create<tensor::DimOp>(loc, input, 2);
Value innerDim1a = rewriter.create<tensor::DimOp>(loc, input, 3);
Value innerDim0b =
@ -2463,42 +2460,21 @@ public:
rewriter.create<arith::DivFOp>(loc, innerDim1d, twoFloat);
Value grid = adaptor.getGrid();
auto gridType = cast<RankedTensorType>(grid.getType());
auto gridShape = gridType.getShape();
auto gridRank = gridType.getRank();
SmallVector<Value> extractGridOffsets0(gridRank, zeroIndex);
SmallVector<Value> extractGridShape = getTensorSizes(rewriter, loc, grid);
SmallVector<Value> extractGridStride(gridRank, oneIndex);
int64_t lastGridDim = gridRank - 1;
extractGridShape[lastGridDim] = oneIndex;
extractGridStride[lastGridDim] = twoIndex;
SmallVector<Value> extractGridOffsets1(gridRank, zeroIndex);
extractGridOffsets1[lastGridDim] = oneIndex;
SmallVector<int64_t> gridShapeExtracted(gridShape);
gridShapeExtracted.back() = 1;
SmallVector<int64_t> gridShapeCollapsed{gridShape[0], gridShape[1],
gridShape[2]};
auto grid0 = rewriter.create<tensor::ExtractSliceOp>(
loc, grid, extractGridOffsets0, extractGridShape, extractGridStride);
auto grid1 = rewriter.create<tensor::ExtractSliceOp>(
loc, grid, extractGridOffsets1, extractGridShape, extractGridStride);
SmallVector<ReassociationIndices> associations{ReassociationIndices{0},
ReassociationIndices{1},
ReassociationIndices{2, 3}};
auto gridCollapsed0 =
rewriter.create<tensor::CollapseShapeOp>(loc, grid0, associations);
auto gridCollapsed1 =
rewriter.create<tensor::CollapseShapeOp>(loc, grid1, associations);
AffineMap gridMap = AffineMap::get(4, 0,
{rewriter.getAffineDimExpr(0),
rewriter.getAffineDimExpr(2),
rewriter.getAffineDimExpr(3)},
op->getContext());
SmallVector<AffineMap> gridMaps{gridMap, gridMap,
rewriter.getMultiDimIdentityMap(gridRank)};
SmallVector<AffineMap> gridMaps{
AffineMap::get(
4, 0,
{rewriter.getAffineDimExpr(0), rewriter.getAffineDimExpr(2),
rewriter.getAffineDimExpr(3), rewriter.getAffineConstantExpr(0)},
op->getContext()),
AffineMap::get(
4, 0,
{rewriter.getAffineDimExpr(0), rewriter.getAffineDimExpr(2),
rewriter.getAffineDimExpr(3), rewriter.getAffineConstantExpr(1)},
op->getContext()),
rewriter.getMultiDimIdentityMap(inputType.getRank())};
SmallVector<utils::IteratorType> gridIterators(
gridRank, utils::IteratorType::parallel);
SmallVector<int64_t> resultShape{inputShape[0], inputShape[1], gridShape[1],
gridShape[2]};
auto lambdaExtract = [](OpBuilder &b, Location loc, Value input, Value idxA,
Value idxB, Value idxC, Value idxD) -> Value {
SmallVector<Value> index{idxA, idxB, idxC, idxD};
@ -2539,22 +2515,22 @@ public:
auto resultType = cast<RankedTensorType>(
getTypeConverter()->convertType(op.getResult().getType()));
SmallVector<Value> resultSize{};
if (resultType.isDynamicDim(0))
resultSize.push_back(rewriter.create<tensor::DimOp>(loc, input, 0));
if (resultType.isDynamicDim(1))
resultSize.push_back(rewriter.create<tensor::DimOp>(loc, input, 1));
if (resultType.isDynamicDim(2))
resultSize.push_back(rewriter.create<tensor::DimOp>(loc, grid, 1));
if (resultType.isDynamicDim(3))
resultSize.push_back(rewriter.create<tensor::DimOp>(loc, grid, 2));
Value alignCorners = adaptor.getAlignCorners();
Value interMode = adaptor.getInterpolationMode();
Value resultFinal =
rewriter.create<tensor::EmptyOp>(loc, resultType, resultSize);
SmallVector<Value> dynamicSizes{};
if (resultType.isDynamicDim(0))
dynamicSizes.push_back(rewriter.create<tensor::DimOp>(loc, input, 0));
if (resultType.isDynamicDim(1))
dynamicSizes.push_back(rewriter.create<tensor::DimOp>(loc, input, 1));
if (resultType.isDynamicDim(2))
dynamicSizes.push_back(rewriter.create<tensor::DimOp>(loc, grid, 1));
if (resultType.isDynamicDim(3))
dynamicSizes.push_back(rewriter.create<tensor::DimOp>(loc, grid, 2));
tensor::EmptyOp emptyOp =
rewriter.create<tensor::EmptyOp>(loc, resultType, dynamicSizes);
auto sGrid = rewriter.create<linalg::GenericOp>(
loc, TypeRange{resultType}, ValueRange{gridCollapsed0, gridCollapsed1},
ValueRange(resultFinal), gridMaps, gridIterators,
loc, TypeRange{resultType}, ValueRange{grid, grid}, ValueRange(emptyOp),
gridMaps, gridIterators,
[&](OpBuilder &b, Location loc, ValueRange args) {
Value gr0 = args[1];
Value gr1 = args[0];

View File

@ -5,9 +5,7 @@
// CHECK-DAG: %[[TC0:.*]] = torch_c.to_builtin_tensor %[[ARG0:.*]] : !torch.vtensor<[4,10,10,4],f32> -> tensor<4x10x10x4xf32>
// CHECK-DAG: %[[TC1:.*]] = torch_c.to_builtin_tensor %[[ARG1:.*]] : !torch.vtensor<[4,6,8,2],f32> -> tensor<4x6x8x2xf32>
// CHECK-DAG: %[[FALSE:.*]] = torch.constant.bool false
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
// CHECK-DAG: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
// CHECK-DAG: %[[CST1:.*]] = arith.constant 1.000000e+00 : f32
// CHECK-DAG: %[[CST2:.*]] = arith.constant 2.000000e+00 : f32