mirror of https://github.com/llvm/torch-mlir
[Stablehlo] Add `AtenIndexTensor` StableHlo support (#2107)
* Add AtenIndexTensor StableHlo support * clean up * Empty commit, trigger test * try to debug hanging test * fix segfulat * fix bad include --------- Co-authored-by: zhekun.zhang <zhekun.zhang@bytedance.com>pull/2154/merge
parent
a426363b7d
commit
eb8f56aeb7
|
@ -456,6 +456,8 @@ STABLEHLO_PASS_SET = {
|
|||
"IndexSelectWholeDimensionModule_basic",
|
||||
"IndexSelectWholeTensorModule_basic",
|
||||
"IndexSelectNegativeDimModule_basic",
|
||||
"IndexTensorStaticModule_basic",
|
||||
"IndexTensorMultiIndexStaticModule_basic",
|
||||
"LayerNormLastDimModule_basic",
|
||||
"LayerNormModule_basic",
|
||||
"LayerNormNormalizeOverAllDimsModule_basic",
|
||||
|
|
|
@ -15,12 +15,13 @@
|
|||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "stablehlo/dialect/StablehloOps.h"
|
||||
#include "torch-mlir/Conversion/TorchToStablehlo/StablehloLegalizeUtils.h"
|
||||
#include "torch-mlir/Conversion/Utils/Utils.h"
|
||||
#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
|
||||
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
|
||||
#include "torch-mlir/Dialect/Torch/IR/TorchTypes.h"
|
||||
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
|
||||
#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h"
|
||||
#include "torch-mlir/Conversion/TorchToStablehlo/StablehloLegalizeUtils.h"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::torch;
|
||||
|
@ -375,6 +376,159 @@ LogicalResult ConvertAtenOp<AtenSliceScatterOp>::matchAndRewrite(
|
|||
return success();
|
||||
}
|
||||
|
||||
// AtenIndexTensorOp
|
||||
// Convert AtenIndexTensorOp to StableHlo::GatherOp
|
||||
// Step 1: broadcast indices to the same shape
|
||||
// Step 2: reshape broadcasted indices to have extra last dimension and concat
|
||||
// Step 3: Create StableHlo::GatherOp with input tensor and indices
|
||||
//
|
||||
// Example:
|
||||
// Input: [[1, 2, 3],
|
||||
// [4, 5, 6],
|
||||
// [7, 8, 9]]
|
||||
// Indices[0]: [[0, 0, 0],
|
||||
// [2, 2, 0]]
|
||||
// Indices[1]: [[2],
|
||||
// [1]]
|
||||
// Step 1:
|
||||
// Indices[0]: [[0, 0, 0],
|
||||
// [2, 2, 0]]
|
||||
// Indices[1]: [[2, 2, 2],
|
||||
// [1, 1, 1]]
|
||||
// Step 2:
|
||||
// Indices: [[[0, 2], [0, 2], [0, 2]],
|
||||
// [[2, 1], [2, 1], [0, 1]]]
|
||||
// Step 3:
|
||||
// Output: [[3, 3, 3],
|
||||
// [8, 8, 2]]
|
||||
template <>
|
||||
LogicalResult ConvertAtenOp<AtenIndexTensorOp>::matchAndRewrite(
|
||||
AtenIndexTensorOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
Location loc = op->getLoc();
|
||||
Value input = adaptor.getSelf();
|
||||
auto inputTensorType = input.getType().dyn_cast<RankedTensorType>();
|
||||
// Check input is a tensor type.
|
||||
if (!inputTensorType)
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "Only tensor types input are currently supported");
|
||||
Value indexList = op.getIndices();
|
||||
SmallVector<Value> indicesTorchType;
|
||||
if (!getListConstructElements(indexList, indicesTorchType))
|
||||
return op.emitError(
|
||||
"unimplemented: the tensor list is not from list construct");
|
||||
|
||||
auto indexTensors = getTypeConvertedValues(rewriter, loc, getTypeConverter(),
|
||||
indicesTorchType);
|
||||
|
||||
// Step 1: broadcast indices tensors
|
||||
int maxRank = -1;
|
||||
SmallVector<int64_t> indicesShape;
|
||||
SmallVector<int64_t> expandShape;
|
||||
SmallVector<int64_t> concatShape;
|
||||
// concat index tensor into to indices tensor for concat
|
||||
for (size_t i = 0; i < indexTensors.size(); i++) {
|
||||
auto indexTensor = indexTensors[i];
|
||||
auto indexTorchTensor = indicesTorchType[i];
|
||||
// TODO: add support for none index input
|
||||
if (indexTorchTensor.getType().isa<Torch::NoneType>())
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "Only list ranked tensor types index are supported");
|
||||
auto indexTensorType = indexTensor.getType().cast<RankedTensorType>();
|
||||
for (int64_t size : makeShapeTorchCompatible(indexTensorType.getShape())) {
|
||||
if (size == kUnknownSize)
|
||||
return rewriter.notifyMatchFailure(op, "Dynamic index support TBD");
|
||||
}
|
||||
maxRank = std::max(maxRank, (int)indexTensorType.getRank());
|
||||
}
|
||||
|
||||
RankedTensorType resultType =
|
||||
getTypeConverter()->convertType(op.getType()).cast<RankedTensorType>();
|
||||
SmallVector<int64_t> refinedResultShape =
|
||||
makeShapeTorchCompatible(resultType.getShape());
|
||||
for (int64_t size : refinedResultShape) {
|
||||
if (size == kUnknownSize)
|
||||
return rewriter.notifyMatchFailure(op, "Dynamic index support TBD");
|
||||
}
|
||||
for (int i = 0; i < maxRank; i++) {
|
||||
indicesShape.push_back(refinedResultShape[i]);
|
||||
expandShape.push_back(refinedResultShape[i]);
|
||||
concatShape.push_back(refinedResultShape[i]);
|
||||
}
|
||||
if (indexTensors.size() > 1) {
|
||||
expandShape.push_back(1);
|
||||
concatShape.push_back(indexTensors.size());
|
||||
}
|
||||
|
||||
SmallVector<Value> broadcastedIndices;
|
||||
Type indexElemTy =
|
||||
indexTensors[0].getType().cast<RankedTensorType>().getElementType();
|
||||
RankedTensorType bcastIndexType =
|
||||
RankedTensorType::get(indicesShape, indexElemTy);
|
||||
for (auto indexTensor : indexTensors) {
|
||||
Value bcastVal =
|
||||
hlo::promoteAndBroadcast(rewriter, indexTensor, bcastIndexType);
|
||||
if (indexTensors.size() > 1) {
|
||||
RankedTensorType reshapeType =
|
||||
RankedTensorType::get(expandShape, indexElemTy);
|
||||
bcastVal =
|
||||
rewriter.create<stablehlo::ReshapeOp>(loc, reshapeType, bcastVal);
|
||||
}
|
||||
broadcastedIndices.push_back(bcastVal);
|
||||
}
|
||||
|
||||
// Step 2: concat index tensors
|
||||
Value finalIndexTensor = broadcastedIndices[0];
|
||||
if (broadcastedIndices.size() > 1) {
|
||||
RankedTensorType concatTy = RankedTensorType::get(concatShape, indexElemTy);
|
||||
finalIndexTensor = rewriter.create<stablehlo::ConcatenateOp>(
|
||||
loc, concatTy, ValueRange(broadcastedIndices), concatShape.size() - 1);
|
||||
}
|
||||
|
||||
// Step 3: create stablehlo::GatherOp
|
||||
RankedTensorType finalIndexTy =
|
||||
finalIndexTensor.getType().cast<RankedTensorType>();
|
||||
int64_t indicesRank = finalIndexTy.getRank();
|
||||
int64_t numIndicesDim = broadcastedIndices.size();
|
||||
int64_t indexVecDim = numIndicesDim > 1 ? indicesRank - 1 : indicesRank;
|
||||
|
||||
SmallVector<int64_t> offsetDims;
|
||||
SmallVector<int64_t> collapsedDims;
|
||||
SmallVector<int64_t> startIndexMap;
|
||||
for (int64_t i = 0; i < numIndicesDim; ++i) {
|
||||
collapsedDims.push_back(i);
|
||||
startIndexMap.push_back(i);
|
||||
}
|
||||
for (int64_t i = numIndicesDim; i < inputTensorType.getRank(); i++) {
|
||||
if (numIndicesDim > 1) {
|
||||
offsetDims.push_back(i + indicesRank - 1 - numIndicesDim);
|
||||
} else {
|
||||
offsetDims.push_back(i + indicesRank - numIndicesDim);
|
||||
}
|
||||
}
|
||||
auto dimsAttr = stablehlo::GatherDimensionNumbersAttr::get(
|
||||
rewriter.getContext(),
|
||||
/*offsetDims=*/offsetDims,
|
||||
/*collapsedSliceDims=*/collapsedDims,
|
||||
/*startIndexMap=*/startIndexMap,
|
||||
/*indexVecDim=*/indexVecDim);
|
||||
|
||||
SmallVector<int64_t> sliceSizes;
|
||||
auto inputShape = makeShapeTorchCompatible(inputTensorType.getShape());
|
||||
for (int64_t i = 0; i < inputTensorType.getRank(); ++i) {
|
||||
if (i < numIndicesDim) {
|
||||
sliceSizes.push_back(1);
|
||||
} else {
|
||||
sliceSizes.push_back(inputShape[i]);
|
||||
}
|
||||
}
|
||||
|
||||
rewriter.replaceOpWithNewOp<stablehlo::GatherOp>(
|
||||
op, resultType, input, finalIndexTensor, dimsAttr,
|
||||
rewriter.getI64TensorAttr(sliceSizes));
|
||||
return success();
|
||||
}
|
||||
|
||||
void mlir::torch::torch_to_stablehlo::
|
||||
populateGatherScatterOpPatternsAndLegality(
|
||||
TypeConverter &typeConverter, RewritePatternSet &patterns,
|
||||
|
@ -388,5 +542,6 @@ void mlir::torch::torch_to_stablehlo::
|
|||
INSERT_ATENOP_PATTERN(AtenIndexSelectOp);
|
||||
INSERT_ATENOP_PATTERN(AtenGatherOp);
|
||||
INSERT_ATENOP_PATTERN(AtenSliceScatterOp);
|
||||
INSERT_ATENOP_PATTERN(AtenIndexTensorOp);
|
||||
#undef INSERT_ATENOP_PATTERN
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue