mirror of https://github.com/llvm/torch-mlir
548 lines
21 KiB
C++
548 lines
21 KiB
C++
//===----------------------------------------------------------------------===//
|
||
//
|
||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||
// See https://llvm.org/LICENSE.txt for license information.
|
||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||
// Also available under a BSD-style license. See LICENSE.
|
||
//
|
||
//===----------------------------------------------------------------------===//
|
||
|
||
#include "torch-mlir/Conversion/TorchToStablehlo/TorchToStablehlo.h"
|
||
|
||
#include "../PassDetail.h"
|
||
#include "PopulatePatterns.h"
|
||
|
||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||
#include "stablehlo/dialect/StablehloOps.h"
|
||
#include "torch-mlir/Conversion/TorchToStablehlo/StablehloLegalizeUtils.h"
|
||
#include "torch-mlir/Conversion/Utils/Utils.h"
|
||
#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
|
||
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
|
||
#include "torch-mlir/Dialect/Torch/IR/TorchTypes.h"
|
||
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
|
||
#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h"
|
||
|
||
using namespace mlir;
|
||
using namespace mlir::torch;
|
||
using namespace mlir::torch::Torch;
|
||
using namespace mlir::torch::torch_to_stablehlo;
|
||
|
||
namespace {
|
||
Value gatherTensorAlongSingleAxis(PatternRewriter &rewriter, Operation *op,
|
||
Value input, Value indices, int64_t axis,
|
||
size_t dimSizeIndexBits) {
|
||
auto loc = op->getLoc();
|
||
Type intType = rewriter.getIntegerType(dimSizeIndexBits);
|
||
Value one = rewriter.create<arith::ConstantOp>(
|
||
loc, rewriter.getIntegerAttr(intType, 1));
|
||
|
||
// sliceSizes
|
||
auto inputRankTy = input.getType().dyn_cast<RankedTensorType>();
|
||
auto inputRank = inputRankTy.getRank();
|
||
SmallVector<Value, 4> sliceSizes;
|
||
sliceSizes.reserve(inputRank);
|
||
for (int64_t r = 0; r < inputRank; ++r) {
|
||
if (r == axis) {
|
||
sliceSizes.push_back(one);
|
||
} else {
|
||
sliceSizes.push_back(rewriter.create<arith::IndexCastOp>(
|
||
loc, intType, rewriter.create<tensor::DimOp>(loc, input, r)));
|
||
}
|
||
}
|
||
auto sliceSizesTensor =
|
||
rewriter.create<tensor::FromElementsOp>(loc, sliceSizes);
|
||
|
||
// offsetDims
|
||
SmallVector<int64_t, 4> offsetDims;
|
||
offsetDims.reserve(inputRank);
|
||
for (int64_t r = 0; r < axis; ++r) {
|
||
offsetDims.push_back(r);
|
||
}
|
||
auto indicesRankTy = indices.getType().dyn_cast<RankedTensorType>();
|
||
auto indicesRank = indicesRankTy.getRank();
|
||
for (int64_t r = axis + 1; r < inputRank; ++r) {
|
||
offsetDims.push_back(r + indicesRank - 1);
|
||
}
|
||
|
||
// collapsedSliceDims
|
||
SmallVector<int64_t, 4> collapsedSliceDims(1, axis);
|
||
// startIndexMap
|
||
SmallVector<int64_t, 4> startIndexMap(1, axis);
|
||
// indexVecDim
|
||
int64_t indexVecDim = indicesRank;
|
||
auto dimsAttr = stablehlo::GatherDimensionNumbersAttr::get(
|
||
rewriter.getContext(),
|
||
/*offsetDims=*/offsetDims,
|
||
/*collapsedSliceDims=*/collapsedSliceDims,
|
||
/*startIndexMap=*/startIndexMap,
|
||
/*indexVecDim=*/indexVecDim);
|
||
|
||
// outputShape = input.shape[:axis] + indices.shape +
|
||
// input.shape[axis + 1:]
|
||
auto inputShape = inputRankTy.getShape();
|
||
auto indicesShape = indicesRankTy.getShape();
|
||
SmallVector<int64_t, 4> outputShape(inputShape.begin(),
|
||
inputShape.begin() + axis);
|
||
outputShape.insert(outputShape.end(), indicesShape.begin(),
|
||
indicesShape.end());
|
||
outputShape.insert(outputShape.end(), inputShape.begin() + axis + 1,
|
||
inputShape.end());
|
||
|
||
// create output tensor type
|
||
auto outputTy =
|
||
RankedTensorType::get(outputShape, inputRankTy.getElementType());
|
||
return rewriter
|
||
.create<stablehlo::DynamicGatherOp>(loc, outputTy, input, indices,
|
||
sliceSizesTensor, dimsAttr)
|
||
.getResult();
|
||
}
|
||
|
||
template <typename OpTy, typename OpAdaptor>
|
||
LogicalResult prepareArgumentsForSlicingOp(OpTy op, OpAdaptor adaptor,
|
||
ConversionPatternRewriter &rewriter,
|
||
SmallVector<Value> &resultShape,
|
||
SmallVector<Value> &offsets,
|
||
SmallVector<Value> &strides) {
|
||
Location loc = op.getLoc();
|
||
auto input = adaptor.getSelf();
|
||
RankedTensorType inputType =
|
||
input.getType().template cast<RankedTensorType>();
|
||
|
||
Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
|
||
Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
|
||
|
||
int64_t dim;
|
||
if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim)))
|
||
return op->emitError("unimplemented: dim is not constant");
|
||
|
||
int64_t inputRank = inputType.getRank();
|
||
dim = toPositiveDim(dim, inputRank);
|
||
if (!isValidDim(dim, inputRank))
|
||
return rewriter.notifyMatchFailure(op, "dim is statically invalid");
|
||
|
||
SmallVector<Value> inputShape = getTensorSizes(rewriter, loc, input);
|
||
Value dimSize = inputShape[dim];
|
||
|
||
Value torchTypeStart = op.getStart();
|
||
Value torchTypeEnd = op.getEnd();
|
||
Value builtinTypeStart = adaptor.getStart();
|
||
Value builtinTypeEnd = adaptor.getEnd();
|
||
|
||
if (torchTypeStart.getType().isa<OptionalType>() ||
|
||
torchTypeEnd.getType().isa<OptionalType>())
|
||
return rewriter.notifyMatchFailure(op, "unimplemented optional type arg");
|
||
|
||
int64_t step;
|
||
if (!matchPattern(op.getStep(), m_TorchConstantInt(&step))) {
|
||
if (!op.getStep().getType().template isa<Torch::NoneType>())
|
||
return op->emitError("unimplemented: step is not constant");
|
||
step = 1;
|
||
}
|
||
|
||
Value start = toPositiveValidDim(rewriter, loc, torchTypeStart,
|
||
builtinTypeStart, zero, dimSize);
|
||
Value end = toPositiveValidDim(rewriter, loc, torchTypeEnd, builtinTypeEnd,
|
||
dimSize, dimSize);
|
||
|
||
// end >= start ? end : start
|
||
Value endSgeStart = rewriter.create<arith::CmpIOp>(
|
||
loc, arith::CmpIPredicate::sge, end, start);
|
||
end = rewriter.create<arith::SelectOp>(loc, endSgeStart, end, start);
|
||
Value stepIndex = rewriter.create<arith::ConstantIndexOp>(loc, step);
|
||
|
||
// Slice logic: resultSize = floordiv(end - start + step - 1, step)
|
||
resultShape = getTensorSizes(rewriter, loc, input);
|
||
Value len = rewriter.create<arith::SubIOp>(loc, end, start);
|
||
Value resultSize = rewriter.create<arith::AddIOp>(loc, len, stepIndex);
|
||
resultSize = rewriter.create<arith::SubIOp>(loc, resultSize, one);
|
||
resultSize = rewriter.create<arith::FloorDivSIOp>(loc, resultSize, stepIndex);
|
||
resultShape[dim] = resultSize;
|
||
|
||
strides.resize(inputType.getRank(), one);
|
||
offsets.resize(inputType.getRank(), zero);
|
||
|
||
offsets[dim] = start;
|
||
strides[dim] = rewriter.create<arith::MulIOp>(loc, strides[dim], stepIndex);
|
||
return success();
|
||
}
|
||
} // namespace
|
||
|
||
// Ref:
|
||
// https://pytorch.org/docs/stable/generated/torch.nn.functional.embedding.html
|
||
// padding_idx (int, optional)
|
||
// – If specified, the entries at padding_idx do not contribute to the
|
||
// gradient; therefore, the embedding vector at padding_idx is not updated
|
||
// during training, i.e. it remains as a fixed “pad”.
|
||
// scale_grad_by_freq (boolean, optional)
|
||
// – If given, this will scale gradients by the inverse of frequency of the
|
||
// words in the mini-batch. Default False.
|
||
// sparse (bool, optional)
|
||
// – If True, gradient w.r.t. weight matrix will be a sparse tensor.
|
||
template <>
|
||
LogicalResult ConvertAtenOp<AtenEmbeddingOp>::matchAndRewrite(
|
||
AtenEmbeddingOp op, OpAdaptor adaptor,
|
||
ConversionPatternRewriter &rewriter) const {
|
||
auto weight = adaptor.getWeight();
|
||
auto weightTy = weight.getType().cast<RankedTensorType>();
|
||
if (!weightTy)
|
||
return op.emitError("only ranked tensor types are supported");
|
||
|
||
int64_t padding_idx;
|
||
if (!matchPattern(op.getPaddingIdx(), m_TorchConstantInt(&padding_idx)))
|
||
return rewriter.notifyMatchFailure(
|
||
op, "only constant padding_idx is currently supported");
|
||
|
||
bool scale_grad_by_freq;
|
||
if (!matchPattern(op.getScaleGradByFreq(),
|
||
m_TorchConstantBool(&scale_grad_by_freq)))
|
||
return rewriter.notifyMatchFailure(
|
||
op, "only constant scale_grad_by_freq is currently supported");
|
||
if (scale_grad_by_freq)
|
||
return rewriter.notifyMatchFailure(
|
||
op, "scale gradients is currently not supported");
|
||
bool sparse;
|
||
if (!matchPattern(op.getSparse(), m_TorchConstantBool(&sparse)))
|
||
return rewriter.notifyMatchFailure(
|
||
op, "only constant sparse is currently supported");
|
||
if (sparse)
|
||
return rewriter.notifyMatchFailure(
|
||
op, "sparse gradients is currently not supported");
|
||
|
||
Value output = gatherTensorAlongSingleAxis(
|
||
rewriter, op, weight, adaptor.getIndices(), 0, options.dimSizeIndexBits);
|
||
rewriter.replaceOpWithNewOp<stablehlo::ConvertOp>(
|
||
op, getTypeConverter()->convertType(op.getType()), output);
|
||
|
||
return success();
|
||
}
|
||
|
||
template <>
|
||
LogicalResult ConvertAtenOp<AtenIndexSelectOp>::matchAndRewrite(
|
||
AtenIndexSelectOp op, OpAdaptor adaptor,
|
||
ConversionPatternRewriter &rewriter) const {
|
||
auto self = adaptor.getSelf();
|
||
auto selfTy = self.getType().cast<RankedTensorType>();
|
||
if (!selfTy)
|
||
return op.emitError("only ranked tensor types are supported");
|
||
int64_t dim;
|
||
if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim)))
|
||
return rewriter.notifyMatchFailure(
|
||
op, "only constant dim is currently supported");
|
||
int64_t inputRank = selfTy.getRank();
|
||
dim = toPositiveDim(dim, inputRank);
|
||
if (!isValidDim(dim, inputRank))
|
||
return rewriter.notifyMatchFailure(op, "dim is statically invalid");
|
||
|
||
Value output = gatherTensorAlongSingleAxis(
|
||
rewriter, op, self, adaptor.getIndex(), dim, options.dimSizeIndexBits);
|
||
|
||
rewriter.replaceOpWithNewOp<stablehlo::ConvertOp>(
|
||
op, getTypeConverter()->convertType(op.getType()), output);
|
||
|
||
return success();
|
||
}
|
||
|
||
// AtenGatherOp
|
||
template <>
|
||
LogicalResult ConvertAtenOp<AtenGatherOp>::matchAndRewrite(
|
||
AtenGatherOp op, OpAdaptor adaptor,
|
||
ConversionPatternRewriter &rewriter) const {
|
||
Location loc = op->getLoc();
|
||
Value input = adaptor.getSelf();
|
||
Value index = adaptor.getIndex();
|
||
auto inputType = input.getType().cast<RankedTensorType>();
|
||
auto indexType = index.getType().cast<RankedTensorType>();
|
||
auto indexElemType = indexType.getElementType();
|
||
|
||
if (indexType.getRank() != inputType.getRank()) {
|
||
return op.emitError("`index` and `input` param should have the same rank");
|
||
}
|
||
int64_t dim;
|
||
if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim))) {
|
||
return rewriter.notifyMatchFailure(
|
||
op, "only constant int `dim` param supported");
|
||
}
|
||
dim = toPositiveDim(dim, inputType.getRank());
|
||
if (!isValidDim(dim, inputType.getRank())) {
|
||
return rewriter.notifyMatchFailure(op, "invalid `dim` param detected");
|
||
}
|
||
|
||
bool sparseGrad = false;
|
||
if (!matchPattern(op.getSparseGrad(), m_TorchConstantBool(&sparseGrad))) {
|
||
return rewriter.notifyMatchFailure(
|
||
op, "only constant boolean `sparse_grad` param supported");
|
||
}
|
||
|
||
auto options = getOptions();
|
||
auto indexShapeInfo =
|
||
hlo::getDimSizesOfTensor(rewriter, op, index, options.dimSizeIndexBits);
|
||
if (failed(indexShapeInfo)) {
|
||
return rewriter.notifyMatchFailure(
|
||
op, "failed to get dim sizes of `index` param");
|
||
}
|
||
auto intType = rewriter.getIntegerType(options.dimSizeIndexBits);
|
||
auto one = rewriter.create<arith::ConstantOp>(
|
||
loc, rewriter.getIntegerAttr(intType, 1));
|
||
auto toConcatIndexShapeValueVec = *indexShapeInfo;
|
||
toConcatIndexShapeValueVec.push_back(one);
|
||
auto toConcatIndexShape =
|
||
rewriter.create<tensor::FromElementsOp>(loc, toConcatIndexShapeValueVec);
|
||
|
||
auto indexShape = indexType.getShape();
|
||
SmallVector<int64_t> toConcatIndexShapeVec(indexShape.begin(),
|
||
indexShape.end());
|
||
toConcatIndexShapeVec.push_back(1);
|
||
RankedTensorType toConcatIndexType =
|
||
RankedTensorType::get(toConcatIndexShapeVec, indexElemType);
|
||
|
||
SmallVector<Value> toConcat;
|
||
for (int64_t i = 0; i < inputType.getRank(); ++i) {
|
||
if (i == dim) {
|
||
toConcat.push_back(rewriter.create<stablehlo::DynamicReshapeOp>(
|
||
loc, toConcatIndexType, index, toConcatIndexShape));
|
||
} else {
|
||
toConcat.push_back(rewriter.create<stablehlo::DynamicIotaOp>(
|
||
loc, toConcatIndexType, toConcatIndexShape,
|
||
rewriter.getI64IntegerAttr(i)));
|
||
}
|
||
}
|
||
auto gatherIndicies = rewriter.create<stablehlo::ConcatenateOp>(
|
||
loc, toConcat, static_cast<uint64_t>(inputType.getRank()));
|
||
SmallVector<int64_t> sliceSizes(inputType.getRank(), 1);
|
||
|
||
int64_t indexVecDim = inputType.getRank();
|
||
SmallVector<int64_t> collapsedDims;
|
||
SmallVector<int64_t> startIndexMap;
|
||
for (int64_t i = 0; i < inputType.getRank(); ++i) {
|
||
collapsedDims.push_back(i);
|
||
startIndexMap.push_back(i);
|
||
}
|
||
|
||
auto dimsAttr = stablehlo::GatherDimensionNumbersAttr::get(
|
||
rewriter.getContext(),
|
||
/*offsetDims=*/{},
|
||
/*collapsedSliceDims=*/collapsedDims,
|
||
/*startIndexMap=*/startIndexMap,
|
||
/*indexVecDim=*/indexVecDim);
|
||
|
||
rewriter.replaceOpWithNewOp<stablehlo::GatherOp>(
|
||
op, input, gatherIndicies, dimsAttr,
|
||
rewriter.getI64TensorAttr(sliceSizes));
|
||
return success();
|
||
}
|
||
|
||
// AtenSliceScatterOp
|
||
template <>
|
||
LogicalResult ConvertAtenOp<AtenSliceScatterOp>::matchAndRewrite(
|
||
AtenSliceScatterOp op, OpAdaptor adaptor,
|
||
ConversionPatternRewriter &rewriter) const {
|
||
|
||
if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
|
||
return failure();
|
||
|
||
Location loc = op.getLoc();
|
||
TypeConverter *typeConverter = getTypeConverter();
|
||
|
||
auto input = adaptor.getSelf();
|
||
|
||
RankedTensorType resultType =
|
||
typeConverter->convertType(op->getResult(0).getType())
|
||
.cast<RankedTensorType>();
|
||
|
||
SmallVector<Value> resultShape;
|
||
SmallVector<Value> offsets;
|
||
SmallVector<Value> strides;
|
||
if (failed(prepareArgumentsForSlicingOp<AtenSliceScatterOp,
|
||
AtenSliceScatterOpAdaptor>(
|
||
op, adaptor, rewriter, resultShape, offsets, strides))) {
|
||
return failure();
|
||
}
|
||
|
||
Value src = adaptor.getSrc();
|
||
auto srcType = src.getType().cast<RankedTensorType>();
|
||
int64_t srcRank = srcType.getRank();
|
||
SmallVector<int64_t> srcAbstractSizes(srcRank, kUnknownSize);
|
||
auto abstractSrcType = RankedTensorType::get(
|
||
makeShapeLLVMCompatible(srcAbstractSizes), srcType.getElementType());
|
||
Value abstractSrc =
|
||
rewriter.create<tensor::CastOp>(loc, abstractSrcType, src);
|
||
|
||
Value result = rewriter.create<tensor::InsertSliceOp>(
|
||
loc, abstractSrc, input, offsets, resultShape, strides);
|
||
|
||
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, result);
|
||
|
||
return success();
|
||
}
|
||
|
||
// AtenIndexTensorOp
|
||
// Convert AtenIndexTensorOp to StableHlo::GatherOp
|
||
// Step 1: broadcast indices to the same shape
|
||
// Step 2: reshape broadcasted indices to have extra last dimension and concat
|
||
// Step 3: Create StableHlo::GatherOp with input tensor and indices
|
||
//
|
||
// Example:
|
||
// Input: [[1, 2, 3],
|
||
// [4, 5, 6],
|
||
// [7, 8, 9]]
|
||
// Indices[0]: [[0, 0, 0],
|
||
// [2, 2, 0]]
|
||
// Indices[1]: [[2],
|
||
// [1]]
|
||
// Step 1:
|
||
// Indices[0]: [[0, 0, 0],
|
||
// [2, 2, 0]]
|
||
// Indices[1]: [[2, 2, 2],
|
||
// [1, 1, 1]]
|
||
// Step 2:
|
||
// Indices: [[[0, 2], [0, 2], [0, 2]],
|
||
// [[2, 1], [2, 1], [0, 1]]]
|
||
// Step 3:
|
||
// Output: [[3, 3, 3],
|
||
// [8, 8, 2]]
|
||
template <>
|
||
LogicalResult ConvertAtenOp<AtenIndexTensorOp>::matchAndRewrite(
|
||
AtenIndexTensorOp op, OpAdaptor adaptor,
|
||
ConversionPatternRewriter &rewriter) const {
|
||
Location loc = op->getLoc();
|
||
Value input = adaptor.getSelf();
|
||
auto inputTensorType = input.getType().dyn_cast<RankedTensorType>();
|
||
// Check input is a tensor type.
|
||
if (!inputTensorType)
|
||
return rewriter.notifyMatchFailure(
|
||
op, "Only tensor types input are currently supported");
|
||
Value indexList = op.getIndices();
|
||
SmallVector<Value> indicesTorchType;
|
||
if (!getListConstructElements(indexList, indicesTorchType))
|
||
return op.emitError(
|
||
"unimplemented: the tensor list is not from list construct");
|
||
|
||
auto indexTensors = getTypeConvertedValues(rewriter, loc, getTypeConverter(),
|
||
indicesTorchType);
|
||
|
||
// Step 1: broadcast indices tensors
|
||
int maxRank = -1;
|
||
SmallVector<int64_t> indicesShape;
|
||
SmallVector<int64_t> expandShape;
|
||
SmallVector<int64_t> concatShape;
|
||
// concat index tensor into to indices tensor for concat
|
||
for (size_t i = 0; i < indexTensors.size(); i++) {
|
||
auto indexTensor = indexTensors[i];
|
||
auto indexTorchTensor = indicesTorchType[i];
|
||
// TODO: add support for none index input
|
||
if (indexTorchTensor.getType().isa<Torch::NoneType>())
|
||
return rewriter.notifyMatchFailure(
|
||
op, "Only list ranked tensor types index are supported");
|
||
auto indexTensorType = indexTensor.getType().cast<RankedTensorType>();
|
||
for (int64_t size : makeShapeTorchCompatible(indexTensorType.getShape())) {
|
||
if (size == kUnknownSize)
|
||
return rewriter.notifyMatchFailure(op, "Dynamic index support TBD");
|
||
}
|
||
maxRank = std::max(maxRank, (int)indexTensorType.getRank());
|
||
}
|
||
|
||
RankedTensorType resultType =
|
||
getTypeConverter()->convertType(op.getType()).cast<RankedTensorType>();
|
||
SmallVector<int64_t> refinedResultShape =
|
||
makeShapeTorchCompatible(resultType.getShape());
|
||
for (int64_t size : refinedResultShape) {
|
||
if (size == kUnknownSize)
|
||
return rewriter.notifyMatchFailure(op, "Dynamic index support TBD");
|
||
}
|
||
for (int i = 0; i < maxRank; i++) {
|
||
indicesShape.push_back(refinedResultShape[i]);
|
||
expandShape.push_back(refinedResultShape[i]);
|
||
concatShape.push_back(refinedResultShape[i]);
|
||
}
|
||
if (indexTensors.size() > 1) {
|
||
expandShape.push_back(1);
|
||
concatShape.push_back(indexTensors.size());
|
||
}
|
||
|
||
SmallVector<Value> broadcastedIndices;
|
||
Type indexElemTy =
|
||
indexTensors[0].getType().cast<RankedTensorType>().getElementType();
|
||
RankedTensorType bcastIndexType =
|
||
RankedTensorType::get(indicesShape, indexElemTy);
|
||
for (auto indexTensor : indexTensors) {
|
||
Value bcastVal =
|
||
hlo::promoteAndBroadcast(rewriter, indexTensor, bcastIndexType);
|
||
if (indexTensors.size() > 1) {
|
||
RankedTensorType reshapeType =
|
||
RankedTensorType::get(expandShape, indexElemTy);
|
||
bcastVal =
|
||
rewriter.create<stablehlo::ReshapeOp>(loc, reshapeType, bcastVal);
|
||
}
|
||
broadcastedIndices.push_back(bcastVal);
|
||
}
|
||
|
||
// Step 2: concat index tensors
|
||
Value finalIndexTensor = broadcastedIndices[0];
|
||
if (broadcastedIndices.size() > 1) {
|
||
RankedTensorType concatTy = RankedTensorType::get(concatShape, indexElemTy);
|
||
finalIndexTensor = rewriter.create<stablehlo::ConcatenateOp>(
|
||
loc, concatTy, ValueRange(broadcastedIndices), concatShape.size() - 1);
|
||
}
|
||
|
||
// Step 3: create stablehlo::GatherOp
|
||
RankedTensorType finalIndexTy =
|
||
finalIndexTensor.getType().cast<RankedTensorType>();
|
||
int64_t indicesRank = finalIndexTy.getRank();
|
||
int64_t numIndicesDim = broadcastedIndices.size();
|
||
int64_t indexVecDim = numIndicesDim > 1 ? indicesRank - 1 : indicesRank;
|
||
|
||
SmallVector<int64_t> offsetDims;
|
||
SmallVector<int64_t> collapsedDims;
|
||
SmallVector<int64_t> startIndexMap;
|
||
for (int64_t i = 0; i < numIndicesDim; ++i) {
|
||
collapsedDims.push_back(i);
|
||
startIndexMap.push_back(i);
|
||
}
|
||
for (int64_t i = numIndicesDim; i < inputTensorType.getRank(); i++) {
|
||
if (numIndicesDim > 1) {
|
||
offsetDims.push_back(i + indicesRank - 1 - numIndicesDim);
|
||
} else {
|
||
offsetDims.push_back(i + indicesRank - numIndicesDim);
|
||
}
|
||
}
|
||
auto dimsAttr = stablehlo::GatherDimensionNumbersAttr::get(
|
||
rewriter.getContext(),
|
||
/*offsetDims=*/offsetDims,
|
||
/*collapsedSliceDims=*/collapsedDims,
|
||
/*startIndexMap=*/startIndexMap,
|
||
/*indexVecDim=*/indexVecDim);
|
||
|
||
SmallVector<int64_t> sliceSizes;
|
||
auto inputShape = makeShapeTorchCompatible(inputTensorType.getShape());
|
||
for (int64_t i = 0; i < inputTensorType.getRank(); ++i) {
|
||
if (i < numIndicesDim) {
|
||
sliceSizes.push_back(1);
|
||
} else {
|
||
sliceSizes.push_back(inputShape[i]);
|
||
}
|
||
}
|
||
|
||
rewriter.replaceOpWithNewOp<stablehlo::GatherOp>(
|
||
op, resultType, input, finalIndexTensor, dimsAttr,
|
||
rewriter.getI64TensorAttr(sliceSizes));
|
||
return success();
|
||
}
|
||
|
||
void mlir::torch::torch_to_stablehlo::
|
||
populateGatherScatterOpPatternsAndLegality(
|
||
TypeConverter &typeConverter, RewritePatternSet &patterns,
|
||
ConversionTarget &target, const TorchToStablehloOptions &options) {
|
||
MLIRContext *context = patterns.getContext();
|
||
|
||
#define INSERT_ATENOP_PATTERN(AtenOp) \
|
||
target.addIllegalOp<AtenOp>(); \
|
||
patterns.add<ConvertAtenOp<AtenOp>>(typeConverter, context, options)
|
||
INSERT_ATENOP_PATTERN(AtenEmbeddingOp);
|
||
INSERT_ATENOP_PATTERN(AtenIndexSelectOp);
|
||
INSERT_ATENOP_PATTERN(AtenGatherOp);
|
||
INSERT_ATENOP_PATTERN(AtenSliceScatterOp);
|
||
INSERT_ATENOP_PATTERN(AtenIndexTensorOp);
|
||
#undef INSERT_ATENOP_PATTERN
|
||
}
|