diff --git a/lib/Conversion/TorchToStablehlo/GatherScatter.cpp b/lib/Conversion/TorchToStablehlo/GatherScatter.cpp
index 7cfa3295f..05c52483c 100644
--- a/lib/Conversion/TorchToStablehlo/GatherScatter.cpp
+++ b/lib/Conversion/TorchToStablehlo/GatherScatter.cpp
@@ -13,7 +13,9 @@
 #include "PopulatePatterns.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Complex/IR/Complex.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "stablehlo/dialect/ChloOps.h"
 #include "stablehlo/dialect/StablehloOps.h"
 #include "torch-mlir/Conversion/TorchToStablehlo/StablehloLegalizeUtils.h"
 #include "torch-mlir/Conversion/Utils/Utils.h"
@@ -900,7 +902,6 @@ LogicalResult ConvertAtenOp<AtenIndexPutHackedTwinOp>::matchAndRewrite(
   for (int64_t i = maxIndexRank; i < inputRank; ++i) {
     updateWindowDims.push_back(i);
   }
-
   auto scatterDimensionNumbers = stablehlo::ScatterDimensionNumbersAttr::get(
       rewriter.getContext(),
       /*updateWindowDims=*/updateWindowDims,
@@ -941,6 +942,412 @@ LogicalResult ConvertAtenOp<AtenIndexPutHackedTwinOp>::matchAndRewrite(
   return success();
 }
 
+// AtenGridSamplerOp
+// See
+// https://github.com/pytorch/pytorch/blob/ec58f1f74ebcec744d2ab90ad34abd09c1018e92/torch/_decomp/decompositions.py#L3923-L4086
+namespace {
+// Materialize a splat constant with the same shape/element type as `val`.
+template <typename T>
+static Value getConstantLike(OpBuilder &b, Location loc, T constant,
+                             Value val) {
+  Type ty = getElementTypeOrSelf(val.getType());
+  auto getAttr = [&]() -> Attribute {
+    if (isa<IntegerType>(ty))
+      return b.getIntegerAttr(ty, constant);
+    if (isa<FloatType>(ty))
+      return b.getFloatAttr(ty, constant);
+    if (auto complexTy = dyn_cast<ComplexType>(ty))
+      return complex::NumberAttr::get(complexTy, constant, 0);
+    llvm_unreachable("unhandled element type");
+  };
+  return b.create<chlo::ConstantLikeOp>(loc, cast<TypedAttr>(getAttr()), val);
+}
+
+// Build a dense constant tensor of the given shape/element type.
+template <typename T>
+static Value getConstTensor(ConversionPatternRewriter &rewriter, Operation *op,
+                            ArrayRef<T> values, ArrayRef<int64_t> shape,
+                            Type ty) {
+  Location loc = op->getLoc();
+  RankedTensorType valueType = RankedTensorType::get(shape, ty);
+  auto valueAttr = DenseElementsAttr::get(valueType, values);
+  return rewriter.create<stablehlo::ConstantOp>(loc, valueType, valueAttr);
+}
+
+// Build a rank-0 (scalar) constant tensor.
+template <typename T>
+static Value getConstScalarTensor(ConversionPatternRewriter &rewriter,
+                                  Operation *op, T value, Type ty) {
+  return getConstTensor<T>(rewriter, op, ArrayRef<T>{value}, {}, ty);
+}
+
+// Helper function to lower AtenGridSamplerOp.
+// Maps normalized grid coordinates in [-1, 1] to pixel coordinates for a
+// dimension of the given size, honoring align_corners semantics.
+static Value unnormalize(ConversionPatternRewriter &rewriter, Operation *op,
+                         Value coords, int64_t size, Type elemTy,
+                         bool alignCorners) {
+  Location loc = op->getLoc();
+  APFloat pointFive(cast<mlir::FloatType>(elemTy).getFloatSemantics(), "0.5");
+  APFloat sizeFloat =
+      APFloat(cast<mlir::FloatType>(elemTy).getFloatSemantics(), size);
+  APFloat one = APFloat(cast<mlir::FloatType>(elemTy).getFloatSemantics(), 1);
+  APFloat zero = APFloat(cast<mlir::FloatType>(elemTy).getFloatSemantics(), 0);
+
+  // double mul = alignCorners ? (size * 0.5 - 0.5) : (size * 0.5);
+  // double ofs = size * 0.5 - 0.5;
+  APFloat mul =
+      alignCorners ? sizeFloat * pointFive - pointFive : sizeFloat * pointFive;
+  APFloat ofs = sizeFloat * pointFive - pointFive;
+  Value constMul = getConstScalarTensor(rewriter, op, mul, elemTy);
+  Value constOfs = getConstScalarTensor(rewriter, op, ofs, elemTy);
+
+  // use chlo::BroadcastMulOp to multiply constMul with coords.
+  DenseI64ArrayAttr bcastDimensions;
+  Value mulResult = rewriter.create<chlo::BroadcastMulOp>(loc, coords, constMul,
+                                                          bcastDimensions);
+  // use chlo::BroadcastAddOp to add constOfs to mulResult.
+  Value result = rewriter.create<chlo::BroadcastAddOp>(loc, mulResult, constOfs,
+                                                       bcastDimensions);
+  return result;
+}
+
+// Apply padding-mode-specific coordinate remapping. Only padding_mode == 0
+// (zeros) is supported today, which needs no remapping.
+static Value computeCoordinates(ConversionPatternRewriter &rewriter,
+                                Operation *op, Value coords, int64_t size,
+                                Type elemTy, int64_t padding_mode) {
+  // TODO: add support for padding_mode 1 and 2.
+  return coords;
+}
+
+// Un-normalize grid coordinates and then apply the padding-mode remapping.
+static Value computeSourceIndex(ConversionPatternRewriter &rewriter,
+                                Operation *op, Value coords, int64_t size,
+                                Type elemTy, int64_t padding_mode,
+                                bool alignCorners) {
+  Value coordsUn =
+      unnormalize(rewriter, op, coords, size, elemTy, alignCorners);
+  return computeCoordinates(rewriter, op, coordsUn, size, elemTy, padding_mode);
+}
+
+// def in_bounds_cond(xs: Tensor, ys: Tensor) -> Tensor:
+//     return torch.logical_and(
+//         0 <= xs, torch.logical_and(xs < iW, torch.logical_and(0 <= ys, ys
+//         < iH))
+//     )
+static Value inBoundsCond(ConversionPatternRewriter &rewriter, Operation *op,
+                          Value xs, Value ys, int64_t ih, int64_t iw,
+                          Type elemTy) {
+  Location loc = op->getLoc();
+  APFloat zeroFloat =
+      APFloat(cast<mlir::FloatType>(elemTy).getFloatSemantics(), 0);
+  Value zero = getConstScalarTensor(rewriter, op, zeroFloat, elemTy);
+  APFloat iwFloat =
+      APFloat(cast<mlir::FloatType>(elemTy).getFloatSemantics(), iw);
+  APFloat ihFloat =
+      APFloat(cast<mlir::FloatType>(elemTy).getFloatSemantics(), ih);
+
+  Value iwFloatValue = getConstScalarTensor(rewriter, op, iwFloat, elemTy);
+  Value ihFloatValue = getConstScalarTensor(rewriter, op, ihFloat, elemTy);
+
+  chlo::ComparisonTypeAttr compareTypeAttr = chlo::ComparisonTypeAttr::get(
+      rewriter.getContext(), chlo::ComparisonType::FLOAT);
+  chlo::ComparisonDirectionAttr compareLTAttr =
+      chlo::ComparisonDirectionAttr::get(rewriter.getContext(),
+                                         chlo::ComparisonDirection::LT);
+  chlo::ComparisonDirectionAttr compareGEAttr =
+      chlo::ComparisonDirectionAttr::get(rewriter.getContext(),
+                                         chlo::ComparisonDirection::GE);
+  DenseI64ArrayAttr bcastDimensions;
+  Value cond1 = rewriter.create<chlo::BroadcastCompareOp>(
+      loc, xs, zero, bcastDimensions, compareGEAttr, compareTypeAttr);
+  Value cond2 = rewriter.create<chlo::BroadcastCompareOp>(
+      loc, xs, iwFloatValue, bcastDimensions, compareLTAttr, compareTypeAttr);
+  Value cond3 = rewriter.create<chlo::BroadcastCompareOp>(
+      loc, ys, zero, bcastDimensions, compareGEAttr, compareTypeAttr);
+  Value cond4 = rewriter.create<chlo::BroadcastCompareOp>(
+      loc, ys, ihFloatValue, bcastDimensions, compareLTAttr, compareTypeAttr);
+  Value cond5 =
+      rewriter.create<chlo::BroadcastAndOp>(loc, cond1, cond2, bcastDimensions);
+  Value cond6 =
+      rewriter.create<chlo::BroadcastAndOp>(loc, cond3, cond4, bcastDimensions);
+  return rewriter.create<chlo::BroadcastAndOp>(loc, cond5, cond6,
+                                               bcastDimensions);
+}
+// def clip(xs: Tensor, ys: Tensor, ws: Tensor) -> TensorSequenceType:
+//     cond = in_bounds_cond(xs, ys)
+//     # To clip to inside valid coordinates, we map the coordinates
+//     # to (x, y) = (0, 0) and also set the weight to 0
+//     # We also change the shape of the tensor to the appropriate one for
+//     # broadcasting with N_idx, C_idx for the purposes of advanced
+//     indexing c = C if _expand_grid else 1
+//     return tuple(
+//         torch.where(cond, t, 0).view(N, c, oH, oW)
+//         for t in (xs.to(dtype=torch.int64), ys.to(dtype=torch.int64), ws)
+//     )
+SmallVector<Value> clip(ConversionPatternRewriter &rewriter, Operation *op,
+                        Value xs, Value ys, Value ws, int64_t N, int64_t oH,
+                        int64_t oW, int64_t iH, int64_t iW, Type elemTy) {
+  Location loc = op->getLoc();
+  auto indexElemTy = rewriter.getI64Type();
+  auto indexTy = RankedTensorType::get(mlir::ArrayRef<int64_t>{1}, indexElemTy);
+
+  Value zeroIntValue = rewriter.create<stablehlo::ConstantOp>(
+      loc, indexTy,
+      DenseIntElementsAttr::get(indexTy, ArrayRef<int64_t>{0}));
+
+  APFloat zeroAPFloat =
+      APFloat(cast<mlir::FloatType>(elemTy).getFloatSemantics(), 0);
+  Value zeroFloatValue = getConstScalarTensor(rewriter, op, zeroAPFloat, elemTy);
+  Value cond = inBoundsCond(rewriter, op, xs, ys, iH, iW, elemTy);
+  Value xsInt = rewriter.create<stablehlo::ConvertOp>(loc, xs, indexElemTy);
+  Value ysInt = rewriter.create<stablehlo::ConvertOp>(loc, ys, indexElemTy);
+
+  Value selectXs = rewriter.create<chlo::BroadcastSelectOp>(
+      loc, ArrayRef<Value>{cond, xsInt, zeroIntValue});
+  Value selectYs = rewriter.create<chlo::BroadcastSelectOp>(
+      loc, ArrayRef<Value>{cond, ysInt, zeroIntValue});
+  Value selectWs = rewriter.create<chlo::BroadcastSelectOp>(
+      loc, ArrayRef<Value>{cond, ws, zeroFloatValue});
+
+  SmallVector<int64_t> sizes = {N, 1, oH, oW};
+  Value reshapedXs = rewriter.create<stablehlo::ReshapeOp>(
+      loc, RankedTensorType::get(sizes, indexElemTy), selectXs);
+  Value reshapedYs = rewriter.create<stablehlo::ReshapeOp>(
+      loc, RankedTensorType::get(sizes, indexElemTy), selectYs);
+  Value reshapedWs = rewriter.create<stablehlo::ReshapeOp>(
+      loc, RankedTensorType::get(sizes, elemTy), selectWs);
+  return SmallVector<Value>{reshapedXs, reshapedYs, reshapedWs};
+}
+
+// Gather the input at the (clipped) integer coordinates and scale by the
+// interpolation weight: one summand of the bilinear/nearest combination.
+Value getSummand(ConversionPatternRewriter &rewriter, Operation *op,
+                 Value input, Value ix, Value iy, Value w, int64_t N,
+                 int64_t oH, int64_t oW, int64_t iH, int64_t iW, Value Nidx,
+                 Value CIdx, RankedTensorType outType, Type elemTy) {
+  Location loc = op->getLoc();
+  auto inputTensorType = cast<RankedTensorType>(input.getType());
+  SmallVector<Value> clipValues =
+      clip(rewriter, op, ix, iy, w, N, oH, oW, iH, iW, elemTy);
+  Value idxX = clipValues[0];
+  Value idxY = clipValues[1];
+  Value idxW = clipValues[2];
+  SmallVector<Value> indexTensors{Nidx, CIdx, idxY, idxX};
+
+  int maxIndexRank = -1;
+  auto gatherIndicesInfo =
+      broadcastAndConcatIndices(input.getDefiningOp(), rewriter, indexTensors,
+                                outType.getShape(), maxIndexRank);
+  auto gatherIndices = *gatherIndicesInfo;
+  int64_t numIndicesDim = indexTensors.size();
+  int64_t indexVecDim = maxIndexRank;
+
+  SmallVector<int64_t> offsetDims;
+  SmallVector<int64_t> collapsedDims;
+  SmallVector<int64_t> startIndexMap;
+  for (int64_t i = 0; i < numIndicesDim; ++i) {
+    collapsedDims.push_back(i);
+    startIndexMap.push_back(i);
+  }
+  for (int64_t i = numIndicesDim; i < inputTensorType.getRank(); i++) {
+    offsetDims.push_back(i + maxIndexRank - numIndicesDim);
+  }
+  auto dimsAttr = stablehlo::GatherDimensionNumbersAttr::get(
+      rewriter.getContext(),
+      /*offsetDims=*/offsetDims,
+      /*collapsedSliceDims=*/collapsedDims,
+      /*operandBatchingDims=*/{},
+      /*startIndicesBatchingDims=*/{},
+      /*startIndexMap=*/startIndexMap,
+      /*indexVecDim=*/indexVecDim);
+
+  SmallVector<int64_t> sliceSizes;
+  auto inputShape = makeShapeTorchCompatible(inputTensorType.getShape());
+  for (int64_t i = 0; i < inputTensorType.getRank(); ++i) {
+    if (i < numIndicesDim) {
+      sliceSizes.push_back(1);
+    } else {
+      sliceSizes.push_back(inputShape[i]);
+    }
+  }
+
+  Value gather = rewriter.create<stablehlo::GatherOp>(
+      loc, input, gatherIndices, dimsAttr,
+      rewriter.getDenseI64ArrayAttr(sliceSizes));
+  // use chlo::BroadcastMulOp to multiply idxW with gather.
+  DenseI64ArrayAttr bcastDimensions;
+  return rewriter.create<chlo::BroadcastMulOp>(loc, gather, idxW,
+                                               bcastDimensions);
+}
+
+} // namespace
+template <>
+LogicalResult ConvertAtenOp<AtenGridSamplerOp>::matchAndRewrite(
+    AtenGridSamplerOp op, OpAdaptor adaptor,
+    ConversionPatternRewriter &rewriter) const {
+  Location loc = op->getLoc();
+  Value input = adaptor.getInput();
+  Value grid = adaptor.getGrid();
+
+  int64_t interpolationMode;
+  if (!matchPattern(op.getInterpolationMode(),
+                    m_TorchConstantInt(&interpolationMode)))
+    return rewriter.notifyMatchFailure(
+        op, "interpolation_mode must be an integer constant");
+  int64_t paddingMode;
+  if (!matchPattern(op.getPaddingMode(), m_TorchConstantInt(&paddingMode)))
+    return rewriter.notifyMatchFailure(
+        op, "padding_mode must be an integer constant");
+
+  if (interpolationMode != 0 && interpolationMode != 1)
+    return rewriter.notifyMatchFailure(
+        op, "only support interpolation_mode = 0 (bilinear) or 1(nearest)");
+
+  if (paddingMode != 0)
+    return rewriter.notifyMatchFailure(op,
+                                       "only support paddingMode = 0 (Zero)");
+
+  bool alignCorners = false;
+  if (!matchPattern(op.getAlignCorners(), m_TorchConstantBool(&alignCorners)))
+    return rewriter.notifyMatchFailure(
+        op, "alignCorners must be a boolean constant");
+
+  RankedTensorType inputTy = cast<RankedTensorType>(input.getType());
+  RankedTensorType gridTy = cast<RankedTensorType>(grid.getType());
+  RankedTensorType outTy =
+      cast<RankedTensorType>(getTypeConverter()->convertType(op.getType()));
+  Type elemTy = inputTy.getElementType();
+  if (inputTy.getRank() != 4)
+    return rewriter.notifyMatchFailure(op, "input must be a 4D tensor");
+  if (gridTy.getRank() != 4)
+    return rewriter.notifyMatchFailure(op, "grid must be a 4D tensor");
+
+  auto inputSize = inputTy.getShape();
+  auto gridSize = gridTy.getShape();
+  int64_t N = inputSize[0];
+  int64_t C = inputSize[1];
+  int64_t iH = inputSize[2];
+  int64_t iW = inputSize[3];
+  int64_t oH = gridSize[1];
+  int64_t oW = gridSize[2];
+  // grid is a 4D tensor with shape (N, oH, oW, 2)
+
+  Type indexElemTy = rewriter.getI64Type();
+  RankedTensorType indexTy =
+      RankedTensorType::get(mlir::ArrayRef<int64_t>{1}, indexElemTy);
+  Value constN = rewriter.create<stablehlo::ConstantOp>(
+      loc, indexTy, DenseIntElementsAttr::get(indexTy, {N}));
+  Value constC = rewriter.create<stablehlo::ConstantOp>(
+      loc, indexTy, DenseIntElementsAttr::get(indexTy, {C}));
+  APFloat one = APFloat(cast<mlir::FloatType>(elemTy).getFloatSemantics(), 1);
+  APFloat zero = APFloat(cast<mlir::FloatType>(elemTy).getFloatSemantics(), 0);
+
+  Value constOneFloat = getConstScalarTensor(rewriter, op, one, elemTy);
+
+  // NOTE(review): builder shape (loc, type, shapeTensor, dim) matches
+  // stablehlo::DynamicIotaOp — produces [0, 1, ..., N-1]; confirm op choice.
+  auto NidxFlatten = rewriter.create<stablehlo::DynamicIotaOp>(
+      loc, RankedTensorType::get(mlir::ArrayRef<int64_t>{N}, indexElemTy),
+      constN, 0);
+  auto CidxFlatten = rewriter.create<stablehlo::DynamicIotaOp>(
+      loc, RankedTensorType::get(mlir::ArrayRef<int64_t>{C}, indexElemTy),
+      constC, 0);
+
+  // Reshape NidxFlatten to 4D tensor (N, 1, 1, 1)
+  auto NidxSizes = mlir::SmallVector<int64_t>{N, 1, 1, 1};
+  auto Nidx = rewriter.create<stablehlo::ReshapeOp>(
+      loc, RankedTensorType::get(NidxSizes, indexElemTy), NidxFlatten);
+
+  // Reshape CidxFlatten to 4D tensor (1, C, 1, 1)
+  auto CidxSizes = mlir::SmallVector<int64_t>{1, C, 1, 1};
+  auto Cidx = rewriter.create<stablehlo::ReshapeOp>(
+      loc, RankedTensorType::get(CidxSizes, indexElemTy), CidxFlatten);
+
+  llvm::SmallVector<int64_t> stride(4, 1);
+  auto gridX = rewriter.create<stablehlo::SliceOp>(
+      loc,
+      RankedTensorType::get(mlir::SmallVector<int64_t>{N, oH, oW, 1},
+                            gridTy.getElementType()),
+      grid, mlir::SmallVector<int64_t>{0, 0, 0, 0},
+      mlir::SmallVector<int64_t>{N, oH, oW, 1}, stride);
+  auto gridY = rewriter.create<stablehlo::SliceOp>(
+      loc,
+      RankedTensorType::get(mlir::SmallVector<int64_t>{N, oH, oW, 1},
+                            gridTy.getElementType()),
+      grid, mlir::SmallVector<int64_t>{0, 0, 0, 1},
+      mlir::SmallVector<int64_t>{N, oH, oW, 2}, stride);
+  // squeeze last dimension
+  auto gridXshape = mlir::SmallVector<int64_t>{N, oH, oW};
+
+  auto gridXReshape = rewriter.create<stablehlo::ReshapeOp>(
+      loc, RankedTensorType::get(gridXshape, gridTy.getElementType()), gridX);
+  auto gridYReshape = rewriter.create<stablehlo::ReshapeOp>(
+      loc, RankedTensorType::get(gridXshape, gridTy.getElementType()), gridY);
+
+  if (interpolationMode == 0) {
+    Value ix = computeSourceIndex(rewriter, op, gridXReshape, iW, elemTy,
+                                  paddingMode, alignCorners);
+    Value iy = computeSourceIndex(rewriter, op, gridYReshape, iH, elemTy,
+                                  paddingMode, alignCorners);
+    Value ix_nw = rewriter.create<stablehlo::FloorOp>(loc, ix);
+    Value iy_nw = rewriter.create<stablehlo::FloorOp>(loc, iy);
+
+    DenseI64ArrayAttr bcastDimensions;
+    Value ix_ne = rewriter.create<chlo::BroadcastAddOp>(
+        loc, ix_nw, constOneFloat, bcastDimensions);
+    Value iy_ne = iy_nw;
+    Value ix_sw = ix_nw;
+    Value iy_sw = rewriter.create<chlo::BroadcastAddOp>(
+        loc, iy_nw, constOneFloat, bcastDimensions);
+    Value ix_se = ix_ne;
+    Value iy_se = iy_sw;
+
+    // w_nw = (ix_se - ix) * (iy_se - iy)
+    // w_ne = (ix - ix_sw) * (iy_sw - iy)
+    // w_sw = (ix_ne - ix) * (iy - iy_ne)
+    // w_se = (ix - ix_nw) * (iy - iy_nw)
+    Value w_nw = rewriter.create<chlo::BroadcastMulOp>(
+        loc,
+        rewriter.create<chlo::BroadcastSubOp>(loc, ix_se, ix, bcastDimensions),
+        rewriter.create<chlo::BroadcastSubOp>(loc, iy_se, iy, bcastDimensions),
+        bcastDimensions);
+    Value w_ne = rewriter.create<chlo::BroadcastMulOp>(
+        loc,
+        rewriter.create<chlo::BroadcastSubOp>(loc, ix, ix_sw, bcastDimensions),
+        rewriter.create<chlo::BroadcastSubOp>(loc, iy_sw, iy, bcastDimensions),
+        bcastDimensions);
+    Value w_sw = rewriter.create<chlo::BroadcastMulOp>(
+        loc,
+        rewriter.create<chlo::BroadcastSubOp>(loc, ix_ne, ix, bcastDimensions),
+        rewriter.create<chlo::BroadcastSubOp>(loc, iy, iy_ne, bcastDimensions),
+        bcastDimensions);
+    Value w_se = rewriter.create<chlo::BroadcastMulOp>(
+        loc,
+        rewriter.create<chlo::BroadcastSubOp>(loc, ix, ix_nw, bcastDimensions),
+        rewriter.create<chlo::BroadcastSubOp>(loc, iy, iy_nw, bcastDimensions),
+        bcastDimensions);
+
+    Value summand_nw = getSummand(rewriter, op, input, ix_nw, iy_nw, w_nw, N,
+                                  oH, oW, iH, iW, Nidx, Cidx, outTy, elemTy);
+    Value summand_ne = getSummand(rewriter, op, input, ix_ne, iy_ne, w_ne, N,
+                                  oH, oW, iH, iW, Nidx, Cidx, outTy, elemTy);
+    Value summand_sw = getSummand(rewriter, op, input, ix_sw, iy_sw, w_sw, N,
+                                  oH, oW, iH, iW, Nidx, Cidx, outTy, elemTy);
+    Value summand_se = getSummand(rewriter, op, input, ix_se, iy_se, w_se, N,
+                                  oH, oW, iH, iW, Nidx, Cidx, outTy, elemTy);
+
+    // summand_nw + summand_ne + summand_sw + summand_se
+    Value sum = rewriter.create<stablehlo::AddOp>(loc, summand_nw, summand_ne);
+    sum = rewriter.create<stablehlo::AddOp>(loc, sum, summand_sw);
+    sum = rewriter.create<stablehlo::AddOp>(loc, sum, summand_se);
+    rewriter.replaceOp(op, sum);
+  } else if (interpolationMode == 1) {
+    Value ix = computeSourceIndex(rewriter, op, gridXReshape, iW, elemTy,
+                                  paddingMode, alignCorners);
+    Value iy = computeSourceIndex(rewriter, op, gridYReshape, iH, elemTy,
+                                  paddingMode, alignCorners);
+    // NOTE(review): torch.round is round-half-to-even, hence
+    // RoundNearestEvenOp rather than RoundOp — confirm against upstream.
+    Value ix_round = rewriter.create<stablehlo::RoundNearestEvenOp>(loc, ix);
+    Value iy_round = rewriter.create<stablehlo::RoundNearestEvenOp>(loc, iy);
+    Value oneTensor = getConstantLike(rewriter, loc, 1.0, ix_round);
+    Value summand =
+        getSummand(rewriter, op, input, ix_round, iy_round, oneTensor, N, oH,
+                   oW, iH, iW, Nidx, Cidx, outTy, elemTy);
+    rewriter.replaceOp(op, summand);
+  }
+  return success();
+}
+
 void mlir::torch::torch_to_stablehlo::
     populateGatherScatterOpPatternsAndLegality(
         TypeConverter &typeConverter, RewritePatternSet &patterns,
@@ -957,6 +1364,7 @@ void mlir::torch::torch_to_stablehlo::
   INSERT_ATENOP_PATTERN(AtenSliceScatterOp);
   INSERT_ATENOP_PATTERN(AtenIndexTensorHackedTwinOp);
   INSERT_ATENOP_PATTERN(AtenIndexPutHackedTwinOp);
+  INSERT_ATENOP_PATTERN(AtenGridSamplerOp);
 #undef INSERT_ATENOP_PATTERN
 
 #define INSERT_ATEN_SCATTER_PATTERN(AtenOp, reduceType)                        \
diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py
index ea1e33b6f..33dd2c082 100644
--- a/projects/pt1/e2e_testing/xfail_sets.py
+++ b/projects/pt1/e2e_testing/xfail_sets.py
@@ -1080,6 +1080,10 @@ STABLEHLO_PASS_SET = {
     "GeIntModule_basic",
     "GeluBackwardModule_basic",
     "GluStaticModule_basic",
+    "GridSamplerBasic1_basic",
+    "GridSamplerBasic2_basic",
+    "GridSamplerBasic3_basic",
+    "GridSamplerBasic4_basic",
     "GtFloatIntModule_basic",
     "GtIntModule_basic",
     "IndexTensorMultiIndexStaticModule_basic",