mirror of https://github.com/llvm/torch-mlir
[Stablehlo] Add lowering of GridSampler Op (#3084)
Inspired by the PyTorch decomposition in decompositions.py. See
https://github.com/pytorch/pytorch/blob/ec58f1f74ebcec744d2ab90ad34abd09c1018e92/torch/_decomp/decompositions.py#L3923-L4086

Only interpolation_mode = 0 (bilinear) or 1 (nearest) and padding_mode = 0 (zeros) are supported.
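As a reference for the semantics being lowered, here is a minimal PyTorch usage sketch (illustrative only, not part of the commit); the integer modes above correspond to the mode/padding_mode strings of torch.nn.functional.grid_sample:

import torch
import torch.nn.functional as F

input = torch.randn(1, 3, 8, 8)        # (N, C, iH, iW)
grid = torch.rand(1, 4, 4, 2) * 2 - 1  # (N, oH, oW, 2), coords in [-1, 1]

# interpolation_mode = 0 -> mode="bilinear"; padding_mode = 0 -> "zeros"
out_bilinear = F.grid_sample(input, grid, mode="bilinear",
                             padding_mode="zeros", align_corners=False)
# interpolation_mode = 1 -> mode="nearest"
out_nearest = F.grid_sample(input, grid, mode="nearest",
                            padding_mode="zeros", align_corners=False)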
commit 431d98b405 (parent 72837fbb3d, from pull/3430/head)
@@ -13,7 +13,9 @@
#include "PopulatePatterns.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "stablehlo/dialect/ChloOps.h"
#include "stablehlo/dialect/StablehloOps.h"
#include "torch-mlir/Conversion/TorchToStablehlo/StablehloLegalizeUtils.h"
#include "torch-mlir/Conversion/Utils/Utils.h"
@@ -900,7 +902,6 @@ LogicalResult ConvertAtenOp<AtenIndexPutHackedTwinOp>::matchAndRewrite(
  for (int64_t i = maxIndexRank; i < inputRank; ++i) {
    updateWindowDims.push_back(i);
  }

  auto scatterDimensionNumbers = stablehlo::ScatterDimensionNumbersAttr::get(
      rewriter.getContext(),
      /*updateWindowDims=*/updateWindowDims,
@@ -941,6 +942,412 @@ LogicalResult ConvertAtenOp<AtenIndexPutHackedTwinOp>::matchAndRewrite(
  return success();
}

// AtenGridSamplerOp
// See
// https://github.com/pytorch/pytorch/blob/ec58f1f74ebcec744d2ab90ad34abd09c1018e92/torch/_decomp/decompositions.py#L3923-L4086
namespace {
template <typename T>
static Value getConstantLike(OpBuilder &b, Location loc, T constant,
                             Value val) {
  Type ty = getElementTypeOrSelf(val.getType());
  auto getAttr = [&]() -> Attribute {
    if (isa<mlir::IntegerType>(ty))
      return b.getIntegerAttr(ty, constant);
    if (isa<mlir::FloatType>(ty))
      return b.getFloatAttr(ty, constant);
    if (auto complexTy = dyn_cast<mlir::ComplexType>(ty))
      return complex::NumberAttr::get(complexTy, constant, 0);
    llvm_unreachable("unhandled element type");
  };
  return b.create<mlir::chlo::ConstantLikeOp>(loc, cast<TypedAttr>(getAttr()),
                                              val);
}

template <typename T>
static Value getConstTensor(ConversionPatternRewriter &rewriter, Operation *op,
                            ArrayRef<T> values, ArrayRef<int64_t> shape,
                            Type ty) {
  Location loc = op->getLoc();
  RankedTensorType valueType = RankedTensorType::get(shape, ty);
  auto valueAttr = DenseElementsAttr::get(valueType, values);
  return rewriter.create<stablehlo::ConstantOp>(loc, valueType, valueAttr);
}

template <typename T>
static Value getConstScalarTensor(ConversionPatternRewriter &rewriter,
                                  Operation *op, T value, Type ty) {
  return getConstTensor(rewriter, op, ArrayRef<T>{value}, {}, ty);
}

// Helper function to lower AtenGridSamplerOp.
static Value unnormalize(ConversionPatternRewriter &rewriter, Operation *op,
                         Value coords, int64_t size, Type elemTy,
                         bool alignCorners) {
  Location loc = op->getLoc();
  APFloat pointFive(cast<mlir::FloatType>(elemTy).getFloatSemantics(), "0.5");
  APFloat sizeFloat =
      APFloat(cast<mlir::FloatType>(elemTy).getFloatSemantics(), size);
  APFloat one = APFloat(cast<mlir::FloatType>(elemTy).getFloatSemantics(), 1);
  APFloat zero = APFloat(cast<mlir::FloatType>(elemTy).getFloatSemantics(), 0);

  // double mul = alignCorners ? (size * 0.5 - 0.5) : (size * 0.5);
  // double ofs = size * 0.5 - 0.5;
  APFloat mul =
      alignCorners ? sizeFloat * pointFive - pointFive : sizeFloat * pointFive;
  APFloat ofs = sizeFloat * pointFive - pointFive;
  Value constMul = getConstScalarTensor(rewriter, op, mul, elemTy);
  Value constOfs = getConstScalarTensor(rewriter, op, ofs, elemTy);

  // use chlo::BroadcastMulOp to multiply constMul with coords.
  DenseI64ArrayAttr bcastDimensions;
  Value mulResult = rewriter.create<chlo::BroadcastMulOp>(loc, coords, constMul,
                                                          bcastDimensions);
  // use chlo::BroadcastAddOp to add constOfs to mulResult.
  Value result = rewriter.create<chlo::BroadcastAddOp>(loc, mulResult, constOfs,
                                                       bcastDimensions);
  return result;
}
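
// Editorial note (not part of the original commit): with size = 4 and
// alignCorners = false, mul = 2.0 and ofs = 1.5, so a normalized coord of
// -1.0 maps to -0.5 and 1.0 maps to 3.5 (the half-pixel-center convention);
// with alignCorners = true, mul = ofs = 1.5, mapping [-1, 1] exactly onto
// [0, 3].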

static Value computeCoordinates(ConversionPatternRewriter &rewriter,
                                Operation *op, Value coords, int64_t size,
                                Type elemTy, int64_t padding_mode) {
  // TODO: add support for padding_mode 1 and 2.
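  // Editorial note (not part of the original commit): for padding_mode = 0
  // (zeros) no coordinate remapping is needed here; out-of-bounds points are
  // handled later by clip()/inBoundsCond, which zeroes their indices and
  // weights.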
  return coords;
}

static Value computeSourceIndex(ConversionPatternRewriter &rewriter,
                                Operation *op, Value coords, int64_t size,
                                Type elemTy, int64_t padding_mode,
                                bool alignCorners) {
  Value coordsUn =
      unnormalize(rewriter, op, coords, size, elemTy, alignCorners);
  return computeCoordinates(rewriter, op, coordsUn, size, elemTy, padding_mode);
}

// def in_bounds_cond(xs: Tensor, ys: Tensor) -> Tensor:
//     return torch.logical_and(
//         0 <= xs, torch.logical_and(xs < iW, torch.logical_and(0 <= ys, ys < iH))
//     )
static Value inBoundsCond(ConversionPatternRewriter &rewriter, Operation *op,
                          Value xs, Value ys, int64_t ih, int64_t iw,
                          Type elemTy) {
  Location loc = op->getLoc();
  APFloat zeroFloat =
      APFloat(cast<mlir::FloatType>(elemTy).getFloatSemantics(), 0);
  Value zero = getConstScalarTensor(rewriter, op, zeroFloat, elemTy);
  APFloat iwFloat =
      APFloat(cast<mlir::FloatType>(elemTy).getFloatSemantics(), iw);
  APFloat ihFloat =
      APFloat(cast<mlir::FloatType>(elemTy).getFloatSemantics(), ih);

  Value iwFloatValue = getConstScalarTensor(rewriter, op, iwFloat, elemTy);
  Value ihFloatValue = getConstScalarTensor(rewriter, op, ihFloat, elemTy);

  chlo::ComparisonTypeAttr compareTypeAttr = chlo::ComparisonTypeAttr::get(
      rewriter.getContext(), chlo::ComparisonType::FLOAT);
  chlo::ComparisonDirectionAttr compareLTAttr =
      chlo::ComparisonDirectionAttr::get(rewriter.getContext(),
                                         chlo::ComparisonDirection::LT);
  chlo::ComparisonDirectionAttr compareGEAttr =
      chlo::ComparisonDirectionAttr::get(rewriter.getContext(),
                                         chlo::ComparisonDirection::GE);
  DenseI64ArrayAttr bcastDimensions;
  Value cond1 = rewriter.create<chlo::BroadcastCompareOp>(
      loc, xs, zero, bcastDimensions, compareGEAttr, compareTypeAttr);
  Value cond2 = rewriter.create<chlo::BroadcastCompareOp>(
      loc, xs, iwFloatValue, bcastDimensions, compareLTAttr, compareTypeAttr);
  Value cond3 = rewriter.create<chlo::BroadcastCompareOp>(
      loc, ys, zero, bcastDimensions, compareGEAttr, compareTypeAttr);
  Value cond4 = rewriter.create<chlo::BroadcastCompareOp>(
      loc, ys, ihFloatValue, bcastDimensions, compareLTAttr, compareTypeAttr);
  Value cond5 =
      rewriter.create<chlo::BroadcastAndOp>(loc, cond1, cond2, bcastDimensions);
  Value cond6 =
      rewriter.create<chlo::BroadcastAndOp>(loc, cond3, cond4, bcastDimensions);
  return rewriter.create<chlo::BroadcastAndOp>(loc, cond5, cond6,
                                               bcastDimensions);
}
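
// Editorial note (not part of the original commit): inBoundsCond yields a
// boolean tensor marking grid points that fall inside the input image;
// clip() below uses it both to clamp out-of-bounds indices to (0, 0) and to
// zero the corresponding interpolation weights.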
// def clip(xs: Tensor, ys: Tensor, ws: Tensor) -> TensorSequenceType:
//     cond = in_bounds_cond(xs, ys)
//     # To clip to inside valid coordinates, we map the coordinates
//     # to (x, y) = (0, 0) and also set the weight to 0
//     # We also change the shape of the tensor to the appropriate one for
//     # broadcasting with N_idx, C_idx for the purposes of advanced indexing
//     c = C if _expand_grid else 1
//     return tuple(
//         torch.where(cond, t, 0).view(N, c, oH, oW)
//         for t in (xs.to(dtype=torch.int64), ys.to(dtype=torch.int64), ws)
//     )
SmallVector<Value> clip(ConversionPatternRewriter &rewriter, Operation *op,
                        Value xs, Value ys, Value ws, int64_t N, int64_t oH,
                        int64_t oW, int64_t iH, int64_t iW, Type elemTy) {
  Location loc = op->getLoc();
  auto indexElemTy = rewriter.getI64Type();
  auto indexTy = RankedTensorType::get(mlir::ArrayRef<int64_t>{1}, indexElemTy);

  Value zeroIntValue = rewriter.create<stablehlo::ConstantOp>(
      loc, indexTy, DenseIntElementsAttr::get(indexTy, ArrayRef<int64_t>{0}));

  APFloat zeroAPFloat =
      APFloat(cast<mlir::FloatType>(elemTy).getFloatSemantics(), 0);
  Value zeroFloatValue =
      getConstScalarTensor(rewriter, op, zeroAPFloat, elemTy);
  Value cond = inBoundsCond(rewriter, op, xs, ys, iH, iW, elemTy);
  Value xsInt = rewriter.create<stablehlo::ConvertOp>(loc, xs, indexElemTy);
  Value ysInt = rewriter.create<stablehlo::ConvertOp>(loc, ys, indexElemTy);

  Value selectXs = rewriter.create<chlo::BroadcastSelectOp>(
      loc, ArrayRef<Value>{cond, xsInt, zeroIntValue});
  Value selectYs = rewriter.create<chlo::BroadcastSelectOp>(
      loc, ArrayRef<Value>{cond, ysInt, zeroIntValue});
  Value selectWs = rewriter.create<chlo::BroadcastSelectOp>(
      loc, ArrayRef<Value>{cond, ws, zeroFloatValue});

  SmallVector<int64_t> sizes = {N, 1, oH, oW};
  Value reshapedXs = rewriter.create<stablehlo::ReshapeOp>(
      loc, RankedTensorType::get(sizes, indexElemTy), selectXs);
  Value reshapedYs = rewriter.create<stablehlo::ReshapeOp>(
      loc, RankedTensorType::get(sizes, indexElemTy), selectYs);
  Value reshapedWs = rewriter.create<stablehlo::ReshapeOp>(
      loc, RankedTensorType::get(sizes, elemTy), selectWs);
  return SmallVector<Value>{reshapedXs, reshapedYs, reshapedWs};
}

Value getSummand(ConversionPatternRewriter &rewriter, Operation *op,
                 Value input, Value ix, Value iy, Value w, int64_t N,
                 int64_t oH, int64_t oW, int64_t iH, int64_t iW, Value Nidx,
                 Value CIdx, RankedTensorType outType, Type elemTy) {
  Location loc = op->getLoc();
  auto inputTensorType = cast<RankedTensorType>(input.getType());
  SmallVector<Value> clipValues =
      clip(rewriter, op, ix, iy, w, N, oH, oW, iH, iW, elemTy);
  Value idxX = clipValues[0];
  Value idxY = clipValues[1];
  Value idxW = clipValues[2];
  SmallVector<Value> indexTensors{Nidx, CIdx, idxY, idxX};

  int maxIndexRank = -1;
  auto gatherIndicesInfo =
      broadcastAndConcatIndices(input.getDefiningOp(), rewriter, indexTensors,
                                outType.getShape(), maxIndexRank);
  auto gatherIndices = *gatherIndicesInfo;
  int64_t numIndicesDim = indexTensors.size();
  int64_t indexVecDim = maxIndexRank;

  SmallVector<int64_t> offsetDims;
  SmallVector<int64_t> collapsedDims;
  SmallVector<int64_t> startIndexMap;
  for (int64_t i = 0; i < numIndicesDim; ++i) {
    collapsedDims.push_back(i);
    startIndexMap.push_back(i);
  }
  for (int64_t i = numIndicesDim; i < inputTensorType.getRank(); i++) {
    offsetDims.push_back(i + maxIndexRank - numIndicesDim);
  }
  auto dimsAttr = stablehlo::GatherDimensionNumbersAttr::get(
      rewriter.getContext(),
      /*offsetDims=*/offsetDims,
      /*collapsedSliceDims=*/collapsedDims,
      /*operandBatchingDims=*/{},
      /*startIndicesBatchingDims=*/{},
      /*startIndexMap=*/startIndexMap,
      /*indexVecDim=*/indexVecDim);

  SmallVector<int64_t> sliceSizes;
  auto inputShape = makeShapeTorchCompatible(inputTensorType.getShape());
  for (int64_t i = 0; i < inputTensorType.getRank(); ++i) {
    if (i < numIndicesDim) {
      sliceSizes.push_back(1);
    } else {
      sliceSizes.push_back(inputShape[i]);
    }
  }

  Value gather = rewriter.create<stablehlo::GatherOp>(
      loc, input, gatherIndices, dimsAttr,
      rewriter.getDenseI64ArrayAttr(sliceSizes));
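  // Editorial note (not part of the original commit): with four index
  // tensors over a rank-4 input, broadcastAndConcatIndices broadcasts
  // (N,1,1,1), (1,C,1,1), (N,1,oH,oW), (N,1,oH,oW) to a common shape of
  // (N, C, oH, oW) and concatenates along a new trailing dim, giving gather
  // indices of shape (N, C, oH, oW, 4); all four input dims are collapsed
  // with slice size 1, so the gather yields input[n, c, iy, ix] with result
  // shape (N, C, oH, oW).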
  // use chlo::BroadcastMulOp to multiply idxW with gather.
  DenseI64ArrayAttr bcastDimensions;
  return rewriter.create<chlo::BroadcastMulOp>(loc, gather, idxW,
                                               bcastDimensions);
}

} // namespace

template <>
LogicalResult ConvertAtenOp<AtenGridSamplerOp>::matchAndRewrite(
    AtenGridSamplerOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  Location loc = op->getLoc();
  Value input = adaptor.getInput();
  Value grid = adaptor.getGrid();

  int64_t interpolationMode;
  if (!matchPattern(op.getInterpolationMode(),
                    m_TorchConstantInt(&interpolationMode)))
    return rewriter.notifyMatchFailure(
        op, "interpolation_mode must be an integer constant");
  int64_t paddingMode;
  if (!matchPattern(op.getPaddingMode(), m_TorchConstantInt(&paddingMode)))
    return rewriter.notifyMatchFailure(
        op, "padding_mode must be an integer constant");

  if (interpolationMode != 0 && interpolationMode != 1)
    return rewriter.notifyMatchFailure(
        op, "only support interpolation_mode = 0 (bilinear) or 1 (nearest)");

  if (paddingMode != 0)
    return rewriter.notifyMatchFailure(op,
                                       "only support paddingMode = 0 (Zero)");

  bool alignCorners = false;
  if (!matchPattern(op.getAlignCorners(), m_TorchConstantBool(&alignCorners)))
    return rewriter.notifyMatchFailure(
        op, "alignCorners must be a boolean constant");

  RankedTensorType inputTy = cast<RankedTensorType>(input.getType());
  RankedTensorType gridTy = cast<RankedTensorType>(grid.getType());
  RankedTensorType outTy =
      cast<RankedTensorType>(getTypeConverter()->convertType(op.getType()));
  Type elemTy = inputTy.getElementType();
  if (inputTy.getRank() != 4)
    return rewriter.notifyMatchFailure(op, "input must be a 4D tensor");
  if (gridTy.getRank() != 4)
    return rewriter.notifyMatchFailure(op, "grid must be a 4D tensor");

  auto inputSize = inputTy.getShape();
  auto gridSize = gridTy.getShape();
  int64_t N = inputSize[0];
  int64_t C = inputSize[1];
  int64_t iH = inputSize[2];
  int64_t iW = inputSize[3];
  int64_t oH = gridSize[1];
  int64_t oW = gridSize[2];
  // grid is a 4D tensor with shape (N, oH, oW, 2)

  Type indexElemTy = rewriter.getI64Type();
  RankedTensorType indexTy =
      RankedTensorType::get(mlir::ArrayRef<int64_t>{1}, indexElemTy);
  Value constN = rewriter.create<stablehlo::ConstantOp>(
      loc, indexTy, DenseIntElementsAttr::get(indexTy, {N}));
  Value constC = rewriter.create<stablehlo::ConstantOp>(
      loc, indexTy, DenseIntElementsAttr::get(indexTy, {C}));
  APFloat one = APFloat(cast<mlir::FloatType>(elemTy).getFloatSemantics(), 1);
  APFloat zero = APFloat(cast<mlir::FloatType>(elemTy).getFloatSemantics(), 0);

  Value constOneFloat = getConstScalarTensor(rewriter, op, one, elemTy);

  auto NidxFlatten = rewriter.create<stablehlo::DynamicIotaOp>(
      loc, RankedTensorType::get(mlir::ArrayRef<int64_t>{N}, indexElemTy),
      constN, 0);
  auto CidxFlatten = rewriter.create<stablehlo::DynamicIotaOp>(
      loc, RankedTensorType::get(mlir::ArrayRef<int64_t>{C}, indexElemTy),
      constC, 0);
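  // Editorial note (not part of the original commit): NidxFlatten and
  // CidxFlatten are iota vectors of batch and channel indices; reshaped
  // below to (N, 1, 1, 1) and (1, C, 1, 1), they play the role of N_idx and
  // C_idx in the PyTorch decomposition's advanced indexing.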

  // Reshape NidxFlatten to 4D tensor (N, 1, 1, 1)
  auto NidxSizes = mlir::SmallVector<int64_t>{N, 1, 1, 1};
  auto Nidx = rewriter.create<stablehlo::ReshapeOp>(
      loc, RankedTensorType::get(NidxSizes, indexElemTy), NidxFlatten);

  // Reshape CidxFlatten to 4D tensor (1, C, 1, 1)
  auto CidxSizes = mlir::SmallVector<int64_t>{1, C, 1, 1};
  auto Cidx = rewriter.create<stablehlo::ReshapeOp>(
      loc, RankedTensorType::get(CidxSizes, indexElemTy), CidxFlatten);

  llvm::SmallVector<int64_t> stride(4, 1);
  auto gridX = rewriter.create<stablehlo::SliceOp>(
      loc,
      RankedTensorType::get(mlir::SmallVector<int64_t>{N, oH, oW, 1},
                            gridTy.getElementType()),
      grid, mlir::SmallVector<int64_t>{0, 0, 0, 0},
      mlir::SmallVector<int64_t>{N, oH, oW, 1}, stride);
  auto gridY = rewriter.create<stablehlo::SliceOp>(
      loc,
      RankedTensorType::get(mlir::SmallVector<int64_t>{N, oH, oW, 1},
                            gridTy.getElementType()),
      grid, mlir::SmallVector<int64_t>{0, 0, 0, 1},
      mlir::SmallVector<int64_t>{N, oH, oW, 2}, stride);
  // squeeze last dimension
  auto gridXshape = mlir::SmallVector<int64_t>{N, oH, oW};

  auto gridXReshape = rewriter.create<stablehlo::ReshapeOp>(
      loc, RankedTensorType::get(gridXshape, gridTy.getElementType()), gridX);
  auto gridYReshape = rewriter.create<stablehlo::ReshapeOp>(
      loc, RankedTensorType::get(gridXshape, gridTy.getElementType()), gridY);
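  // Editorial note (not part of the original commit): gridX and gridY are
  // the slices grid[..., 0] and grid[..., 1] (x and y coordinates), with the
  // trailing singleton dimension squeezed away to shape (N, oH, oW).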

  if (interpolationMode == 0) {
    Value ix = computeSourceIndex(rewriter, op, gridXReshape, iW, elemTy,
                                  paddingMode, alignCorners);
    Value iy = computeSourceIndex(rewriter, op, gridYReshape, iH, elemTy,
                                  paddingMode, alignCorners);
    Value ix_nw = rewriter.create<stablehlo::FloorOp>(loc, ix);
    Value iy_nw = rewriter.create<stablehlo::FloorOp>(loc, iy);

    DenseI64ArrayAttr bcastDimensions;
    Value ix_ne = rewriter.create<chlo::BroadcastAddOp>(
        loc, ix_nw, constOneFloat, bcastDimensions);
    Value iy_ne = iy_nw;
    Value ix_sw = ix_nw;
    Value iy_sw = rewriter.create<chlo::BroadcastAddOp>(
        loc, iy_nw, constOneFloat, bcastDimensions);
    Value ix_se = ix_ne;
    Value iy_se = iy_sw;

    // w_nw = (ix_se - ix) * (iy_se - iy)
    // w_ne = (ix - ix_sw) * (iy_sw - iy)
    // w_sw = (ix_ne - ix) * (iy - iy_ne)
    // w_se = (ix - ix_nw) * (iy - iy_nw)
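    // Editorial note (not part of the original commit): e.g. for ix = 1.3,
    // iy = 2.6 the corners are nw = (1, 2), ne = (2, 2), sw = (1, 3),
    // se = (2, 3), with w_nw = 0.7 * 0.4 = 0.28, w_ne = 0.3 * 0.4 = 0.12,
    // w_sw = 0.7 * 0.6 = 0.42, w_se = 0.3 * 0.6 = 0.18; the four weights
    // sum to 1.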
    Value w_nw = rewriter.create<chlo::BroadcastMulOp>(
        loc,
        rewriter.create<chlo::BroadcastSubOp>(loc, ix_se, ix, bcastDimensions),
        rewriter.create<chlo::BroadcastSubOp>(loc, iy_se, iy, bcastDimensions),
        bcastDimensions);
    Value w_ne = rewriter.create<chlo::BroadcastMulOp>(
        loc,
        rewriter.create<chlo::BroadcastSubOp>(loc, ix, ix_sw, bcastDimensions),
        rewriter.create<chlo::BroadcastSubOp>(loc, iy_sw, iy, bcastDimensions),
        bcastDimensions);
    Value w_sw = rewriter.create<chlo::BroadcastMulOp>(
        loc,
        rewriter.create<chlo::BroadcastSubOp>(loc, ix_ne, ix, bcastDimensions),
        rewriter.create<chlo::BroadcastSubOp>(loc, iy, iy_ne, bcastDimensions),
        bcastDimensions);
    Value w_se = rewriter.create<chlo::BroadcastMulOp>(
        loc,
        rewriter.create<chlo::BroadcastSubOp>(loc, ix, ix_nw, bcastDimensions),
        rewriter.create<chlo::BroadcastSubOp>(loc, iy, iy_nw, bcastDimensions),
        bcastDimensions);

    Value summand_nw = getSummand(rewriter, op, input, ix_nw, iy_nw, w_nw, N,
                                  oH, oW, iH, iW, Nidx, Cidx, outTy, elemTy);
    Value summand_ne = getSummand(rewriter, op, input, ix_ne, iy_ne, w_ne, N,
                                  oH, oW, iH, iW, Nidx, Cidx, outTy, elemTy);
    Value summand_sw = getSummand(rewriter, op, input, ix_sw, iy_sw, w_sw, N,
                                  oH, oW, iH, iW, Nidx, Cidx, outTy, elemTy);
    Value summand_se = getSummand(rewriter, op, input, ix_se, iy_se, w_se, N,
                                  oH, oW, iH, iW, Nidx, Cidx, outTy, elemTy);

    // summand_nw + summand_ne + summand_sw + summand_se
    Value sum = rewriter.create<stablehlo::AddOp>(loc, summand_nw, summand_ne);
    sum = rewriter.create<stablehlo::AddOp>(loc, sum, summand_sw);
    sum = rewriter.create<stablehlo::AddOp>(loc, sum, summand_se);
    rewriter.replaceOp(op, sum);
  } else if (interpolationMode == 1) {
    Value ix = computeSourceIndex(rewriter, op, gridXReshape, iW, elemTy,
                                  paddingMode, alignCorners);
    Value iy = computeSourceIndex(rewriter, op, gridYReshape, iH, elemTy,
                                  paddingMode, alignCorners);
    Value ix_round = rewriter.create<stablehlo::RoundOp>(loc, ix);
    Value iy_round = rewriter.create<stablehlo::RoundOp>(loc, iy);
    Value oneTensor = getConstantLike(rewriter, loc, 1.0, ix_round);
    Value summand =
        getSummand(rewriter, op, input, ix_round, iy_round, oneTensor, N, oH,
                   oW, iH, iW, Nidx, Cidx, outTy, elemTy);
    rewriter.replaceOp(op, summand);
  }
  return success();
}

void mlir::torch::torch_to_stablehlo::
    populateGatherScatterOpPatternsAndLegality(
        TypeConverter &typeConverter, RewritePatternSet &patterns,
@@ -957,6 +1364,7 @@ void mlir::torch::torch_to_stablehlo::
  INSERT_ATENOP_PATTERN(AtenSliceScatterOp);
  INSERT_ATENOP_PATTERN(AtenIndexTensorHackedTwinOp);
  INSERT_ATENOP_PATTERN(AtenIndexPutHackedTwinOp);
  INSERT_ATENOP_PATTERN(AtenGridSamplerOp);
#undef INSERT_ATENOP_PATTERN

#define INSERT_ATEN_SCATTER_PATTERN(AtenOp, reduceType) \
@@ -1080,6 +1080,10 @@ STABLEHLO_PASS_SET = {
    "GeIntModule_basic",
    "GeluBackwardModule_basic",
    "GluStaticModule_basic",
    "GridSamplerBasic1_basic",
    "GridSamplerBasic2_basic",
    "GridSamplerBasic3_basic",
    "GridSamplerBasic4_basic",
    "GtFloatIntModule_basic",
    "GtIntModule_basic",
    "IndexTensorMultiIndexStaticModule_basic",