mirror of https://github.com/llvm/torch-mlir
[TorchToLinalg] remove `extract_slice` grid_sample lowering (#3483)
Instead of using `extract_slice` for the grid sampler, use affine constant expressions in the indexing maps to access the X and Y values in the generic op's region.
pull/3657/head
parent
f66908f190
commit
a24114efa3
|
@ -2431,9 +2431,7 @@ public:
|
|||
Location loc = op->getLoc();
|
||||
Type int64type = rewriter.getI64Type();
|
||||
Type floatType = rewriter.getF32Type();
|
||||
Value zeroIndex = rewriter.create<arith::ConstantIndexOp>(loc, 0);
|
||||
Value oneIndex = rewriter.create<arith::ConstantIndexOp>(loc, 1);
|
||||
Value twoIndex = rewriter.create<arith::ConstantIndexOp>(loc, 2);
|
||||
Value zeroFloat = rewriter.create<arith::ConstantOp>(
|
||||
loc, rewriter.getFloatAttr(floatType, 0.0));
|
||||
Value oneFloat = rewriter.create<arith::ConstantOp>(
|
||||
|
@ -2442,7 +2440,6 @@ public:
|
|||
loc, rewriter.getFloatAttr(floatType, 2.0));
|
||||
Value input = adaptor.getInput();
|
||||
auto inputType = cast<RankedTensorType>(input.getType());
|
||||
auto inputShape = inputType.getShape();
|
||||
Value innerDim0a = rewriter.create<tensor::DimOp>(loc, input, 2);
|
||||
Value innerDim1a = rewriter.create<tensor::DimOp>(loc, input, 3);
|
||||
Value innerDim0b =
|
||||
|
@ -2463,42 +2460,21 @@ public:
|
|||
rewriter.create<arith::DivFOp>(loc, innerDim1d, twoFloat);
|
||||
Value grid = adaptor.getGrid();
|
||||
auto gridType = cast<RankedTensorType>(grid.getType());
|
||||
auto gridShape = gridType.getShape();
|
||||
auto gridRank = gridType.getRank();
|
||||
SmallVector<Value> extractGridOffsets0(gridRank, zeroIndex);
|
||||
SmallVector<Value> extractGridShape = getTensorSizes(rewriter, loc, grid);
|
||||
SmallVector<Value> extractGridStride(gridRank, oneIndex);
|
||||
int64_t lastGridDim = gridRank - 1;
|
||||
extractGridShape[lastGridDim] = oneIndex;
|
||||
extractGridStride[lastGridDim] = twoIndex;
|
||||
SmallVector<Value> extractGridOffsets1(gridRank, zeroIndex);
|
||||
extractGridOffsets1[lastGridDim] = oneIndex;
|
||||
SmallVector<int64_t> gridShapeExtracted(gridShape);
|
||||
gridShapeExtracted.back() = 1;
|
||||
SmallVector<int64_t> gridShapeCollapsed{gridShape[0], gridShape[1],
|
||||
gridShape[2]};
|
||||
auto grid0 = rewriter.create<tensor::ExtractSliceOp>(
|
||||
loc, grid, extractGridOffsets0, extractGridShape, extractGridStride);
|
||||
auto grid1 = rewriter.create<tensor::ExtractSliceOp>(
|
||||
loc, grid, extractGridOffsets1, extractGridShape, extractGridStride);
|
||||
SmallVector<ReassociationIndices> associations{ReassociationIndices{0},
|
||||
ReassociationIndices{1},
|
||||
ReassociationIndices{2, 3}};
|
||||
auto gridCollapsed0 =
|
||||
rewriter.create<tensor::CollapseShapeOp>(loc, grid0, associations);
|
||||
auto gridCollapsed1 =
|
||||
rewriter.create<tensor::CollapseShapeOp>(loc, grid1, associations);
|
||||
AffineMap gridMap = AffineMap::get(4, 0,
|
||||
{rewriter.getAffineDimExpr(0),
|
||||
rewriter.getAffineDimExpr(2),
|
||||
rewriter.getAffineDimExpr(3)},
|
||||
op->getContext());
|
||||
SmallVector<AffineMap> gridMaps{gridMap, gridMap,
|
||||
rewriter.getMultiDimIdentityMap(gridRank)};
|
||||
SmallVector<AffineMap> gridMaps{
|
||||
AffineMap::get(
|
||||
4, 0,
|
||||
{rewriter.getAffineDimExpr(0), rewriter.getAffineDimExpr(2),
|
||||
rewriter.getAffineDimExpr(3), rewriter.getAffineConstantExpr(0)},
|
||||
op->getContext()),
|
||||
AffineMap::get(
|
||||
4, 0,
|
||||
{rewriter.getAffineDimExpr(0), rewriter.getAffineDimExpr(2),
|
||||
rewriter.getAffineDimExpr(3), rewriter.getAffineConstantExpr(1)},
|
||||
op->getContext()),
|
||||
rewriter.getMultiDimIdentityMap(inputType.getRank())};
|
||||
SmallVector<utils::IteratorType> gridIterators(
|
||||
gridRank, utils::IteratorType::parallel);
|
||||
SmallVector<int64_t> resultShape{inputShape[0], inputShape[1], gridShape[1],
|
||||
gridShape[2]};
|
||||
auto lambdaExtract = [](OpBuilder &b, Location loc, Value input, Value idxA,
|
||||
Value idxB, Value idxC, Value idxD) -> Value {
|
||||
SmallVector<Value> index{idxA, idxB, idxC, idxD};
|
||||
|
@ -2539,22 +2515,22 @@ public:
|
|||
|
||||
auto resultType = cast<RankedTensorType>(
|
||||
getTypeConverter()->convertType(op.getResult().getType()));
|
||||
SmallVector<Value> resultSize{};
|
||||
if (resultType.isDynamicDim(0))
|
||||
resultSize.push_back(rewriter.create<tensor::DimOp>(loc, input, 0));
|
||||
if (resultType.isDynamicDim(1))
|
||||
resultSize.push_back(rewriter.create<tensor::DimOp>(loc, input, 1));
|
||||
if (resultType.isDynamicDim(2))
|
||||
resultSize.push_back(rewriter.create<tensor::DimOp>(loc, grid, 1));
|
||||
if (resultType.isDynamicDim(3))
|
||||
resultSize.push_back(rewriter.create<tensor::DimOp>(loc, grid, 2));
|
||||
Value alignCorners = adaptor.getAlignCorners();
|
||||
Value interMode = adaptor.getInterpolationMode();
|
||||
Value resultFinal =
|
||||
rewriter.create<tensor::EmptyOp>(loc, resultType, resultSize);
|
||||
SmallVector<Value> dynamicSizes{};
|
||||
if (resultType.isDynamicDim(0))
|
||||
dynamicSizes.push_back(rewriter.create<tensor::DimOp>(loc, input, 0));
|
||||
if (resultType.isDynamicDim(1))
|
||||
dynamicSizes.push_back(rewriter.create<tensor::DimOp>(loc, input, 1));
|
||||
if (resultType.isDynamicDim(2))
|
||||
dynamicSizes.push_back(rewriter.create<tensor::DimOp>(loc, grid, 1));
|
||||
if (resultType.isDynamicDim(3))
|
||||
dynamicSizes.push_back(rewriter.create<tensor::DimOp>(loc, grid, 2));
|
||||
tensor::EmptyOp emptyOp =
|
||||
rewriter.create<tensor::EmptyOp>(loc, resultType, dynamicSizes);
|
||||
auto sGrid = rewriter.create<linalg::GenericOp>(
|
||||
loc, TypeRange{resultType}, ValueRange{gridCollapsed0, gridCollapsed1},
|
||||
ValueRange(resultFinal), gridMaps, gridIterators,
|
||||
loc, TypeRange{resultType}, ValueRange{grid, grid}, ValueRange(emptyOp),
|
||||
gridMaps, gridIterators,
|
||||
[&](OpBuilder &b, Location loc, ValueRange args) {
|
||||
Value gr0 = args[1];
|
||||
Value gr1 = args[0];
|
||||
|
|
|
@ -5,9 +5,7 @@
|
|||
// CHECK-DAG: %[[TC0:.*]] = torch_c.to_builtin_tensor %[[ARG0:.*]] : !torch.vtensor<[4,10,10,4],f32> -> tensor<4x10x10x4xf32>
|
||||
// CHECK-DAG: %[[TC1:.*]] = torch_c.to_builtin_tensor %[[ARG1:.*]] : !torch.vtensor<[4,6,8,2],f32> -> tensor<4x6x8x2xf32>
|
||||
// CHECK-DAG: %[[FALSE:.*]] = torch.constant.bool false
|
||||
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
|
||||
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
|
||||
// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
|
||||
// CHECK-DAG: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
|
||||
// CHECK-DAG: %[[CST1:.*]] = arith.constant 1.000000e+00 : f32
|
||||
// CHECK-DAG: %[[CST2:.*]] = arith.constant 2.000000e+00 : f32
|
||||
|
|
Loading…
Reference in New Issue