From a24114efa316abede269189df1e2e712b5968721 Mon Sep 17 00:00:00 2001 From: Ian Wood <75152913+IanWood1@users.noreply.github.com> Date: Tue, 20 Aug 2024 14:23:43 -0700 Subject: [PATCH] [TorchToLinalg] remove `extract_slice` grid_sample lowering (#3483) Instead of using extract_slice for grid sampler, use affine constants to access the X and Y values in the generic op's region. --- .../TorchToLinalg/Uncategorized.cpp | 74 +++++++------------ .../Conversion/TorchToLinalg/gridsampler.mlir | 2 - 2 files changed, 25 insertions(+), 51 deletions(-) diff --git a/lib/Conversion/TorchToLinalg/Uncategorized.cpp b/lib/Conversion/TorchToLinalg/Uncategorized.cpp index 7823138c9..31f1a723f 100644 --- a/lib/Conversion/TorchToLinalg/Uncategorized.cpp +++ b/lib/Conversion/TorchToLinalg/Uncategorized.cpp @@ -2431,9 +2431,7 @@ public: Location loc = op->getLoc(); Type int64type = rewriter.getI64Type(); Type floatType = rewriter.getF32Type(); - Value zeroIndex = rewriter.create(loc, 0); Value oneIndex = rewriter.create(loc, 1); - Value twoIndex = rewriter.create(loc, 2); Value zeroFloat = rewriter.create( loc, rewriter.getFloatAttr(floatType, 0.0)); Value oneFloat = rewriter.create( @@ -2442,7 +2440,6 @@ public: loc, rewriter.getFloatAttr(floatType, 2.0)); Value input = adaptor.getInput(); auto inputType = cast(input.getType()); - auto inputShape = inputType.getShape(); Value innerDim0a = rewriter.create(loc, input, 2); Value innerDim1a = rewriter.create(loc, input, 3); Value innerDim0b = @@ -2463,42 +2460,21 @@ public: rewriter.create(loc, innerDim1d, twoFloat); Value grid = adaptor.getGrid(); auto gridType = cast(grid.getType()); - auto gridShape = gridType.getShape(); auto gridRank = gridType.getRank(); - SmallVector extractGridOffsets0(gridRank, zeroIndex); - SmallVector extractGridShape = getTensorSizes(rewriter, loc, grid); - SmallVector extractGridStride(gridRank, oneIndex); - int64_t lastGridDim = gridRank - 1; - extractGridShape[lastGridDim] = oneIndex; - extractGridStride[lastGridDim] = twoIndex; - SmallVector extractGridOffsets1(gridRank, zeroIndex); - extractGridOffsets1[lastGridDim] = oneIndex; - SmallVector gridShapeExtracted(gridShape); - gridShapeExtracted.back() = 1; - SmallVector gridShapeCollapsed{gridShape[0], gridShape[1], - gridShape[2]}; - auto grid0 = rewriter.create( - loc, grid, extractGridOffsets0, extractGridShape, extractGridStride); - auto grid1 = rewriter.create( - loc, grid, extractGridOffsets1, extractGridShape, extractGridStride); - SmallVector associations{ReassociationIndices{0}, - ReassociationIndices{1}, - ReassociationIndices{2, 3}}; - auto gridCollapsed0 = - rewriter.create(loc, grid0, associations); - auto gridCollapsed1 = - rewriter.create(loc, grid1, associations); - AffineMap gridMap = AffineMap::get(4, 0, - {rewriter.getAffineDimExpr(0), - rewriter.getAffineDimExpr(2), - rewriter.getAffineDimExpr(3)}, - op->getContext()); - SmallVector gridMaps{gridMap, gridMap, - rewriter.getMultiDimIdentityMap(gridRank)}; + SmallVector gridMaps{ + AffineMap::get( + 4, 0, + {rewriter.getAffineDimExpr(0), rewriter.getAffineDimExpr(2), + rewriter.getAffineDimExpr(3), rewriter.getAffineConstantExpr(0)}, + op->getContext()), + AffineMap::get( + 4, 0, + {rewriter.getAffineDimExpr(0), rewriter.getAffineDimExpr(2), + rewriter.getAffineDimExpr(3), rewriter.getAffineConstantExpr(1)}, + op->getContext()), + rewriter.getMultiDimIdentityMap(inputType.getRank())}; SmallVector gridIterators( gridRank, utils::IteratorType::parallel); - SmallVector resultShape{inputShape[0], inputShape[1], gridShape[1], - gridShape[2]}; auto lambdaExtract = [](OpBuilder &b, Location loc, Value input, Value idxA, Value idxB, Value idxC, Value idxD) -> Value { SmallVector index{idxA, idxB, idxC, idxD}; @@ -2539,22 +2515,22 @@ public: auto resultType = cast( getTypeConverter()->convertType(op.getResult().getType())); - SmallVector resultSize{}; - if (resultType.isDynamicDim(0)) - resultSize.push_back(rewriter.create(loc, input, 0)); - if (resultType.isDynamicDim(1)) - resultSize.push_back(rewriter.create(loc, input, 1)); - if (resultType.isDynamicDim(2)) - resultSize.push_back(rewriter.create(loc, grid, 1)); - if (resultType.isDynamicDim(3)) - resultSize.push_back(rewriter.create(loc, grid, 2)); Value alignCorners = adaptor.getAlignCorners(); Value interMode = adaptor.getInterpolationMode(); - Value resultFinal = - rewriter.create(loc, resultType, resultSize); + SmallVector dynamicSizes{}; + if (resultType.isDynamicDim(0)) + dynamicSizes.push_back(rewriter.create(loc, input, 0)); + if (resultType.isDynamicDim(1)) + dynamicSizes.push_back(rewriter.create(loc, input, 1)); + if (resultType.isDynamicDim(2)) + dynamicSizes.push_back(rewriter.create(loc, grid, 1)); + if (resultType.isDynamicDim(3)) + dynamicSizes.push_back(rewriter.create(loc, grid, 2)); + tensor::EmptyOp emptyOp = + rewriter.create(loc, resultType, dynamicSizes); auto sGrid = rewriter.create( - loc, TypeRange{resultType}, ValueRange{gridCollapsed0, gridCollapsed1}, - ValueRange(resultFinal), gridMaps, gridIterators, + loc, TypeRange{resultType}, ValueRange{grid, grid}, ValueRange(emptyOp), + gridMaps, gridIterators, [&](OpBuilder &b, Location loc, ValueRange args) { Value gr0 = args[1]; Value gr1 = args[0]; diff --git a/test/Conversion/TorchToLinalg/gridsampler.mlir b/test/Conversion/TorchToLinalg/gridsampler.mlir index 7c099c5ce..2a291f721 100644 --- a/test/Conversion/TorchToLinalg/gridsampler.mlir +++ b/test/Conversion/TorchToLinalg/gridsampler.mlir @@ -5,9 +5,7 @@ // CHECK-DAG: %[[TC0:.*]] = torch_c.to_builtin_tensor %[[ARG0:.*]] : !torch.vtensor<[4,10,10,4],f32> -> tensor<4x10x10x4xf32> // CHECK-DAG: %[[TC1:.*]] = torch_c.to_builtin_tensor %[[ARG1:.*]] : !torch.vtensor<[4,6,8,2],f32> -> tensor<4x6x8x2xf32> // CHECK-DAG: %[[FALSE:.*]] = torch.constant.bool false -// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index -// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index // CHECK-DAG: %[[CST:.*]] = arith.constant 0.000000e+00 : f32 // CHECK-DAG: %[[CST1:.*]] = arith.constant 1.000000e+00 : f32 // CHECK-DAG: %[[CST2:.*]] = arith.constant 2.000000e+00 : f32