[Stablehlo] Add lowering of GridSampler Op (#3084)

Inspired by PyTorch decompositions.py.
See https://github.com/pytorch/pytorch/blob/ec58f1f74ebcec744d2ab90ad34abd09c1018e92/torch/_decomp/decompositions.py#L3923-L4086
Only supports padding_mode=0 (zeros) and interpolation_mode=0 (bilinear) or 1 (nearest)
Xinyu Yang 2024-06-07 16:06:07 +08:00 committed by GitHub
parent 72837fbb3d
commit 431d98b405
2 changed files with 413 additions and 1 deletion
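
For reference, a hand-written sketch (not part of this commit) of the Torch-dialect op this pattern lowers; the shapes are illustrative:

    %out = torch.aten.grid_sampler %input, %grid, %int0, %int0, %false
        : !torch.vtensor<[1,3,8,8],f32>, !torch.vtensor<[1,4,4,2],f32>,
          !torch.int, !torch.int, !torch.bool -> !torch.vtensor<[1,3,4,4],f32>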


@@ -13,7 +13,9 @@
#include "PopulatePatterns.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "stablehlo/dialect/ChloOps.h"
#include "stablehlo/dialect/StablehloOps.h"
#include "torch-mlir/Conversion/TorchToStablehlo/StablehloLegalizeUtils.h"
#include "torch-mlir/Conversion/Utils/Utils.h"
@@ -900,7 +902,6 @@ LogicalResult ConvertAtenOp<AtenIndexPutHackedTwinOp>::matchAndRewrite(
for (int64_t i = maxIndexRank; i < inputRank; ++i) {
updateWindowDims.push_back(i);
}
auto scatterDimensionNumbers = stablehlo::ScatterDimensionNumbersAttr::get(
rewriter.getContext(),
/*updateWindowDims=*/updateWindowDims,
@@ -941,6 +942,412 @@ LogicalResult ConvertAtenOp<AtenIndexPutHackedTwinOp>::matchAndRewrite(
return success();
}
// AtenGridSamplerOp
// See
// https://github.com/pytorch/pytorch/blob/ec58f1f74ebcec744d2ab90ad34abd09c1018e92/torch/_decomp/decompositions.py#L3923-L4086
namespace {
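// Creates a chlo.constant_like op holding `constant`, splat to the shape and
// element type of `val`.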
template <typename T>
static Value getConstantLike(OpBuilder &b, Location loc, T constant,
Value val) {
Type ty = getElementTypeOrSelf(val.getType());
auto getAttr = [&]() -> Attribute {
if (isa<mlir::IntegerType>(ty))
return b.getIntegerAttr(ty, constant);
if (isa<mlir::FloatType>(ty))
return b.getFloatAttr(ty, constant);
if (auto complexTy = dyn_cast<mlir::ComplexType>(ty))
return complex::NumberAttr::get(complexTy, constant, 0);
llvm_unreachable("unhandled element type");
};
return b.create<mlir::chlo::ConstantLikeOp>(loc, cast<TypedAttr>(getAttr()),
val);
}
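// Materializes a dense stablehlo.constant of the given shape and element type.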
template <typename T>
static Value getConstTensor(ConversionPatternRewriter &rewriter, Operation *op,
ArrayRef<T> values, ArrayRef<int64_t> shape,
Type ty) {
Location loc = op->getLoc();
RankedTensorType valueType = RankedTensorType::get(shape, ty);
auto valueAttr = DenseElementsAttr::get(valueType, values);
return rewriter.create<stablehlo::ConstantOp>(loc, valueType, valueAttr);
}
template <typename T>
static Value getConstScalarTensor(ConversionPatternRewriter &rewriter,
Operation *op, T value, Type ty) {
return getConstTensor(rewriter, op, ArrayRef<T>{value}, {}, ty);
}
// Helper function to lower AtenGridSamplerOp.
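// Maps normalized grid coordinates in [-1, 1] to pixel coordinates:
//   align_corners=true:  x_pix = (x + 1) / 2 * (size - 1)
//                              = x * (size * 0.5 - 0.5) + (size * 0.5 - 0.5)
//   align_corners=false: x_pix = ((x + 1) * size - 1) / 2
//                              = x * (size * 0.5) + (size * 0.5 - 0.5)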
static Value unnormalize(ConversionPatternRewriter &rewriter, Operation *op,
Value coords, int64_t size, Type elemTy,
bool alignCorners) {
Location loc = op->getLoc();
APFloat pointFive(cast<mlir::FloatType>(elemTy).getFloatSemantics(), "0.5");
APFloat sizeFloat =
APFloat(cast<mlir::FloatType>(elemTy).getFloatSemantics(), size);
// double mul = alignCorners ? (size * 0.5 - 0.5) : (size * 0.5);
// double ofs = size * 0.5 - 0.5;
APFloat mul =
alignCorners ? sizeFloat * pointFive - pointFive : sizeFloat * pointFive;
APFloat ofs = sizeFloat * pointFive - pointFive;
Value constMul = getConstScalarTensor(rewriter, op, mul, elemTy);
Value constOfs = getConstScalarTensor(rewriter, op, ofs, elemTy);
// use chlo::BroadcastMulOp to multiply constMul with coords.
DenseI64ArrayAttr bcastDimensions;
Value mulResult = rewriter.create<chlo::BroadcastMulOp>(loc, coords, constMul,
bcastDimensions);
// use chlo::BroadcastAddOp to add constOfs to mulResult.
Value result = rewriter.create<chlo::BroadcastAddOp>(loc, mulResult, constOfs,
bcastDimensions);
return result;
}
static Value computeCoordinates(ConversionPatternRewriter &rewriter,
Operation *op, Value coords, int64_t size,
Type elemTy, int64_t padding_mode) {
// For padding_mode = 0 (zeros) no coordinate remapping is needed here;
// out-of-bounds samples are masked out later via inBoundsCond / clip.
// TODO: add support for padding_mode = 1 (border) and 2 (reflection).
return coords;
}
static Value computeSourceIndex(ConversionPatternRewriter &rewriter,
Operation *op, Value coords, int64_t size,
Type elemTy, int64_t padding_mode,
bool alignCorners) {
Value coordsUn =
unnormalize(rewriter, op, coords, size, elemTy, alignCorners);
return computeCoordinates(rewriter, op, coordsUn, size, elemTy, padding_mode);
}
// def in_bounds_cond(xs: Tensor, ys: Tensor) -> Tensor:
//     return torch.logical_and(
//         0 <= xs, torch.logical_and(xs < iW, torch.logical_and(0 <= ys, ys < iH))
//     )
static Value inBoundsCond(ConversionPatternRewriter &rewriter, Operation *op,
Value xs, Value ys, int64_t ih, int64_t iw,
Type elemTy) {
Location loc = op->getLoc();
APFloat zeroFloat =
APFloat(cast<mlir::FloatType>(elemTy).getFloatSemantics(), 0);
Value zero = getConstScalarTensor(rewriter, op, zeroFloat, elemTy);
APFloat iwFloat =
APFloat(cast<mlir::FloatType>(elemTy).getFloatSemantics(), iw);
APFloat ihFloat =
APFloat(cast<mlir::FloatType>(elemTy).getFloatSemantics(), ih);
Value iwFloatValue = getConstScalarTensor(rewriter, op, iwFloat, elemTy);
Value ihFloatValue = getConstScalarTensor(rewriter, op, ihFloat, elemTy);
chlo::ComparisonTypeAttr compareTypeAttr = chlo::ComparisonTypeAttr::get(
rewriter.getContext(), chlo::ComparisonType::FLOAT);
chlo::ComparisonDirectionAttr compareLTAttr =
chlo::ComparisonDirectionAttr::get(rewriter.getContext(),
chlo::ComparisonDirection::LT);
chlo::ComparisonDirectionAttr compareGEAttr =
chlo::ComparisonDirectionAttr::get(rewriter.getContext(),
chlo::ComparisonDirection::GE);
DenseI64ArrayAttr bcastDimensions;
Value cond1 = rewriter.create<chlo::BroadcastCompareOp>(
loc, xs, zero, bcastDimensions, compareGEAttr, compareTypeAttr);
Value cond2 = rewriter.create<chlo::BroadcastCompareOp>(
loc, xs, iwFloatValue, bcastDimensions, compareLTAttr, compareTypeAttr);
Value cond3 = rewriter.create<chlo::BroadcastCompareOp>(
loc, ys, zero, bcastDimensions, compareGEAttr, compareTypeAttr);
Value cond4 = rewriter.create<chlo::BroadcastCompareOp>(
loc, ys, ihFloatValue, bcastDimensions, compareLTAttr, compareTypeAttr);
Value cond5 =
rewriter.create<chlo::BroadcastAndOp>(loc, cond1, cond2, bcastDimensions);
Value cond6 =
rewriter.create<chlo::BroadcastAndOp>(loc, cond3, cond4, bcastDimensions);
return rewriter.create<chlo::BroadcastAndOp>(loc, cond5, cond6,
bcastDimensions);
}
// def clip(xs: Tensor, ys: Tensor, ws: Tensor) -> TensorSequenceType:
//     cond = in_bounds_cond(xs, ys)
//     # To clip to inside valid coordinates, we map the coordinates
//     # to (x, y) = (0, 0) and also set the weight to 0
//     # We also change the shape of the tensor to the appropriate one for
//     # broadcasting with N_idx, C_idx for the purposes of advanced indexing
//     c = C if _expand_grid else 1
//     return tuple(
//         torch.where(cond, t, 0).view(N, c, oH, oW)
//         for t in (xs.to(dtype=torch.int64), ys.to(dtype=torch.int64), ws)
//     )
SmallVector<Value> clip(ConversionPatternRewriter &rewriter, Operation *op,
Value xs, Value ys, Value ws, int64_t N, int64_t oH,
int64_t oW, int64_t iH, int64_t iW, Type elemTy) {
Location loc = op->getLoc();
auto indexElemTy = rewriter.getI64Type();
auto indexTy = RankedTensorType::get(mlir::ArrayRef<int64_t>{1}, indexElemTy);
Value zeroIntValue = rewriter.create<stablehlo::ConstantOp>(
loc, indexTy, DenseIntElementsAttr::get(indexTy, ArrayRef<int64_t>{0}));
APFloat zeroAPFloat =
APFloat(cast<mlir::FloatType>(elemTy).getFloatSemantics(), 0);
Value zeroFloatValue =
getConstScalarTensor(rewriter, op, zeroAPFloat, elemTy);
Value cond = inBoundsCond(rewriter, op, xs, ys, iH, iW, elemTy);
Value xsInt = rewriter.create<stablehlo::ConvertOp>(loc, xs, indexElemTy);
Value ysInt = rewriter.create<stablehlo::ConvertOp>(loc, ys, indexElemTy);
Value selectXs = rewriter.create<chlo::BroadcastSelectOp>(
loc, ArrayRef<Value>{cond, xsInt, zeroIntValue});
Value selectYs = rewriter.create<chlo::BroadcastSelectOp>(
loc, ArrayRef<Value>{cond, ysInt, zeroIntValue});
Value selectWs = rewriter.create<chlo::BroadcastSelectOp>(
loc, ArrayRef<Value>{cond, ws, zeroFloatValue});
SmallVector<int64_t> sizes = {N, 1, oH, oW};
Value reshapedXs = rewriter.create<stablehlo::ReshapeOp>(
loc, RankedTensorType::get(sizes, indexElemTy), selectXs);
Value reshapedYs = rewriter.create<stablehlo::ReshapeOp>(
loc, RankedTensorType::get(sizes, indexElemTy), selectYs);
Value reshapedWs = rewriter.create<stablehlo::ReshapeOp>(
loc, RankedTensorType::get(sizes, elemTy), selectWs);
return SmallVector<Value>{reshapedXs, reshapedYs, reshapedWs};
}
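// Gathers input values at integer indices (n, c, iy, ix) with a
// stablehlo.gather and scales them by the interpolation weight w; samples
// that were out of bounds contribute 0 because clip() zeroed their weights.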
Value getSummand(ConversionPatternRewriter &rewriter, Operation *op,
Value input, Value ix, Value iy, Value w, int64_t N,
int64_t oH, int64_t oW, int64_t iH, int64_t iW, Value Nidx,
Value CIdx, RankedTensorType outType, Type elemTy) {
Location loc = op->getLoc();
auto inputTensorType = cast<RankedTensorType>(input.getType());
SmallVector<Value> clipValues =
clip(rewriter, op, ix, iy, w, N, oH, oW, iH, iW, elemTy);
Value idxX = clipValues[0];
Value idxY = clipValues[1];
Value idxW = clipValues[2];
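// Advanced-indexing order matches the NCHW input layout: input[n, c, y, x].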
SmallVector<Value> indexTensors{Nidx, CIdx, idxY, idxX};
int maxIndexRank = -1;
auto gatherIndicesInfo =
broadcastAndConcatIndices(input.getDefiningOp(), rewriter, indexTensors,
outType.getShape(), maxIndexRank);
auto gatherIndices = *gatherIndicesInfo;
int64_t numIndicesDim = indexTensors.size();
int64_t indexVecDim = maxIndexRank;
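// All four indexed dims are gathered point-wise, so each becomes a collapsed
// slice dim; with a rank-4 input no offset dims remain.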
SmallVector<int64_t> offsetDims;
SmallVector<int64_t> collapsedDims;
SmallVector<int64_t> startIndexMap;
for (int64_t i = 0; i < numIndicesDim; ++i) {
collapsedDims.push_back(i);
startIndexMap.push_back(i);
}
for (int64_t i = numIndicesDim; i < inputTensorType.getRank(); i++) {
offsetDims.push_back(i + maxIndexRank - numIndicesDim);
}
auto dimsAttr = stablehlo::GatherDimensionNumbersAttr::get(
rewriter.getContext(),
/*offsetDims=*/offsetDims,
/*collapsedSliceDims=*/collapsedDims,
/*operandBatchingDims=*/{},
/*startIndicesBatchingDims=*/{},
/*startIndexMap=*/startIndexMap,
/*indexVecDim=*/indexVecDim);
SmallVector<int64_t> sliceSizes;
auto inputShape = makeShapeTorchCompatible(inputTensorType.getShape());
for (int64_t i = 0; i < inputTensorType.getRank(); ++i) {
if (i < numIndicesDim) {
sliceSizes.push_back(1);
} else {
sliceSizes.push_back(inputShape[i]);
}
}
Value gather = rewriter.create<stablehlo::GatherOp>(
loc, input, gatherIndices, dimsAttr,
rewriter.getDenseI64ArrayAttr(sliceSizes));
// use chlo::BroadcastMulOp to multiply idxW with gather.
DenseI64ArrayAttr bcastDimensions;
return rewriter.create<chlo::BroadcastMulOp>(loc, gather, idxW,
bcastDimensions);
}
} // namespace
template <>
LogicalResult ConvertAtenOp<AtenGridSamplerOp>::matchAndRewrite(
AtenGridSamplerOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
Location loc = op->getLoc();
Value input = adaptor.getInput();
Value grid = adaptor.getGrid();
int64_t interpolationMode;
if (!matchPattern(op.getInterpolationMode(),
m_TorchConstantInt(&interpolationMode)))
return rewriter.notifyMatchFailure(
op, "interpolation_mode must be an integer constant");
int64_t paddingMode;
if (!matchPattern(op.getPaddingMode(), m_TorchConstantInt(&paddingMode)))
return rewriter.notifyMatchFailure(
op, "padding_mode must be an integer constant");
if (interpolationMode != 0 && interpolationMode != 1)
return rewriter.notifyMatchFailure(
op, "only support interpolation_mode = 0 (bilinear) or 1 (nearest)");
if (paddingMode != 0)
return rewriter.notifyMatchFailure(op,
"only support padding_mode = 0 (zeros)");
bool alignCorners = false;
if (!matchPattern(op.getAlignCorners(), m_TorchConstantBool(&alignCorners)))
return rewriter.notifyMatchFailure(
op, "alignCorners must be a boolean constant");
RankedTensorType inputTy = cast<RankedTensorType>(input.getType());
RankedTensorType gridTy = cast<RankedTensorType>(grid.getType());
RankedTensorType outTy =
cast<RankedTensorType>(getTypeConverter()->convertType(op.getType()));
Type elemTy = inputTy.getElementType();
if (inputTy.getRank() != 4)
return rewriter.notifyMatchFailure(op, "input must be a 4D tensor");
if (gridTy.getRank() != 4)
return rewriter.notifyMatchFailure(op, "grid must be a 4D tensor");
auto inputSize = inputTy.getShape();
auto gridSize = gridTy.getShape();
int64_t N = inputSize[0];
int64_t C = inputSize[1];
int64_t iH = inputSize[2];
int64_t iW = inputSize[3];
int64_t oH = gridSize[1];
int64_t oW = gridSize[2];
// grid is a 4D tensor with shape (N, oH, oW, 2)
Type indexElemTy = rewriter.getI64Type();
RankedTensorType indexTy =
RankedTensorType::get(mlir::ArrayRef<int64_t>{1}, indexElemTy);
Value constN = rewriter.create<stablehlo::ConstantOp>(
loc, indexTy, DenseIntElementsAttr::get(indexTy, {N}));
Value constC = rewriter.create<stablehlo::ConstantOp>(
loc, indexTy, DenseIntElementsAttr::get(indexTy, {C}));
APFloat one = APFloat(cast<mlir::FloatType>(elemTy).getFloatSemantics(), 1);
Value constOneFloat = getConstScalarTensor(rewriter, op, one, elemTy);
auto NidxFlatten = rewriter.create<stablehlo::DynamicIotaOp>(
loc, RankedTensorType::get(mlir::ArrayRef<int64_t>{N}, indexElemTy),
constN, 0);
auto CidxFlatten = rewriter.create<stablehlo::DynamicIotaOp>(
loc, RankedTensorType::get(mlir::ArrayRef<int64_t>{C}, indexElemTy),
constC, 0);
// Reshape NidxFlatten to 4D tensor (N, 1, 1, 1)
auto NidxSizes = mlir::SmallVector<int64_t>{N, 1, 1, 1};
auto Nidx = rewriter.create<stablehlo::ReshapeOp>(
loc, RankedTensorType::get(NidxSizes, indexElemTy), NidxFlatten);
// Reshape CidxFlatten to 4D tensor (1, C, 1, 1)
auto CidxSizes = mlir::SmallVector<int64_t>{1, C, 1, 1};
auto Cidx = rewriter.create<stablehlo::ReshapeOp>(
loc, RankedTensorType::get(CidxSizes, indexElemTy), CidxFlatten);
llvm::SmallVector<int64_t> stride(4, 1);
auto gridX = rewriter.create<stablehlo::SliceOp>(
loc,
RankedTensorType::get(mlir::SmallVector<int64_t>{N, oH, oW, 1},
gridTy.getElementType()),
grid, mlir::SmallVector<int64_t>{0, 0, 0, 0},
mlir::SmallVector<int64_t>{N, oH, oW, 1}, stride);
auto gridY = rewriter.create<stablehlo::SliceOp>(
loc,
RankedTensorType::get(mlir::SmallVector<int64_t>{N, oH, oW, 1},
gridTy.getElementType()),
grid, mlir::SmallVector<int64_t>{0, 0, 0, 1},
mlir::SmallVector<int64_t>{N, oH, oW, 2}, stride);
// squeeze last dimension
auto gridXshape = mlir::SmallVector<int64_t>{N, oH, oW};
auto gridXReshape = rewriter.create<stablehlo::ReshapeOp>(
loc, RankedTensorType::get(gridXshape, gridTy.getElementType()), gridX);
auto gridYReshape = rewriter.create<stablehlo::ReshapeOp>(
loc, RankedTensorType::get(gridXshape, gridTy.getElementType()), gridY);
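// interpolation_mode == 0 (bilinear): blend the four surrounding input
// pixels with area weights; interpolation_mode == 1 (nearest): pick the
// rounded source index with weight 1.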
if (interpolationMode == 0) {
Value ix = computeSourceIndex(rewriter, op, gridXReshape, iW, elemTy,
paddingMode, alignCorners);
Value iy = computeSourceIndex(rewriter, op, gridYReshape, iH, elemTy,
paddingMode, alignCorners);
Value ix_nw = rewriter.create<stablehlo::FloorOp>(loc, ix);
Value iy_nw = rewriter.create<stablehlo::FloorOp>(loc, iy);
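// Corners around (ix, iy): nw = (floor(ix), floor(iy)), ne = nw + (1, 0),
// sw = nw + (0, 1), se = nw + (1, 1).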
DenseI64ArrayAttr bcastDimensions;
Value ix_ne = rewriter.create<chlo::BroadcastAddOp>(
loc, ix_nw, constOneFloat, bcastDimensions);
Value iy_ne = iy_nw;
Value ix_sw = ix_nw;
Value iy_sw = rewriter.create<chlo::BroadcastAddOp>(
loc, iy_nw, constOneFloat, bcastDimensions);
Value ix_se = ix_ne;
Value iy_se = iy_sw;
// w_nw = (ix_se - ix) * (iy_se - iy)
// w_ne = (ix - ix_sw) * (iy_sw - iy)
// w_sw = (ix_ne - ix) * (iy - iy_ne)
// w_se = (ix - ix_nw) * (iy - iy_nw)
Value w_nw = rewriter.create<chlo::BroadcastMulOp>(
loc,
rewriter.create<chlo::BroadcastSubOp>(loc, ix_se, ix, bcastDimensions),
rewriter.create<chlo::BroadcastSubOp>(loc, iy_se, iy, bcastDimensions),
bcastDimensions);
Value w_ne = rewriter.create<chlo::BroadcastMulOp>(
loc,
rewriter.create<chlo::BroadcastSubOp>(loc, ix, ix_sw, bcastDimensions),
rewriter.create<chlo::BroadcastSubOp>(loc, iy_sw, iy, bcastDimensions),
bcastDimensions);
Value w_sw = rewriter.create<chlo::BroadcastMulOp>(
loc,
rewriter.create<chlo::BroadcastSubOp>(loc, ix_ne, ix, bcastDimensions),
rewriter.create<chlo::BroadcastSubOp>(loc, iy, iy_ne, bcastDimensions),
bcastDimensions);
Value w_se = rewriter.create<chlo::BroadcastMulOp>(
loc,
rewriter.create<chlo::BroadcastSubOp>(loc, ix, ix_nw, bcastDimensions),
rewriter.create<chlo::BroadcastSubOp>(loc, iy, iy_nw, bcastDimensions),
bcastDimensions);
Value summand_nw = getSummand(rewriter, op, input, ix_nw, iy_nw, w_nw, N,
oH, oW, iH, iW, Nidx, Cidx, outTy, elemTy);
Value summand_ne = getSummand(rewriter, op, input, ix_ne, iy_ne, w_ne, N,
oH, oW, iH, iW, Nidx, Cidx, outTy, elemTy);
Value summand_sw = getSummand(rewriter, op, input, ix_sw, iy_sw, w_sw, N,
oH, oW, iH, iW, Nidx, Cidx, outTy, elemTy);
Value summand_se = getSummand(rewriter, op, input, ix_se, iy_se, w_se, N,
oH, oW, iH, iW, Nidx, Cidx, outTy, elemTy);
// summand_nw + summand_ne + summand_sw + summand_se
Value sum = rewriter.create<stablehlo::AddOp>(loc, summand_nw, summand_ne);
sum = rewriter.create<stablehlo::AddOp>(loc, sum, summand_sw);
sum = rewriter.create<stablehlo::AddOp>(loc, sum, summand_se);
rewriter.replaceOp(op, sum);
} else if (interpolationMode == 1) {
Value ix = computeSourceIndex(rewriter, op, gridXReshape, iW, elemTy,
paddingMode, alignCorners);
Value iy = computeSourceIndex(rewriter, op, gridYReshape, iH, elemTy,
paddingMode, alignCorners);
Value ix_round = rewriter.create<stablehlo::RoundOp>(loc, ix);
Value iy_round = rewriter.create<stablehlo::RoundOp>(loc, iy);
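// Nearest mode gathers a single sample per output element, so the
// interpolation weight is the constant 1.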
Value oneTensor = getConstantLike(rewriter, loc, 1.0, ix_round);
Value summand =
getSummand(rewriter, op, input, ix_round, iy_round, oneTensor, N, oH,
oW, iH, iW, Nidx, Cidx, outTy, elemTy);
rewriter.replaceOp(op, summand);
}
return success();
}
void mlir::torch::torch_to_stablehlo::
populateGatherScatterOpPatternsAndLegality(
TypeConverter &typeConverter, RewritePatternSet &patterns,
@@ -957,6 +1364,7 @@ void mlir::torch::torch_to_stablehlo::
INSERT_ATENOP_PATTERN(AtenSliceScatterOp);
INSERT_ATENOP_PATTERN(AtenIndexTensorHackedTwinOp);
INSERT_ATENOP_PATTERN(AtenIndexPutHackedTwinOp);
INSERT_ATENOP_PATTERN(AtenGridSamplerOp);
#undef INSERT_ATENOP_PATTERN
#define INSERT_ATEN_SCATTER_PATTERN(AtenOp, reduceType) \


@@ -1080,6 +1080,10 @@ STABLEHLO_PASS_SET = {
"GeIntModule_basic",
"GeluBackwardModule_basic",
"GluStaticModule_basic",
"GridSamplerBasic1_basic",
"GridSamplerBasic2_basic",
"GridSamplerBasic3_basic",
"GridSamplerBasic4_basic",
"GtFloatIntModule_basic",
"GtIntModule_basic",
"IndexTensorMultiIndexStaticModule_basic",