mirror of https://github.com/llvm/torch-mlir
[Stablehlo] Add `AtenIndexTensor` StableHlo support (#2107)

* Add AtenIndexTensor StableHlo support
* clean up
* Empty commit, trigger test
* try to debug hanging test
* fix segfault
* fix bad include

Co-authored-by: zhekun.zhang <zhekun.zhang@bytedance.com>

parent a426363b7d
commit eb8f56aeb7

@@ -456,6 +456,8 @@ STABLEHLO_PASS_SET = {
"IndexSelectWholeDimensionModule_basic",
|
"IndexSelectWholeDimensionModule_basic",
|
||||||
"IndexSelectWholeTensorModule_basic",
|
"IndexSelectWholeTensorModule_basic",
|
||||||
"IndexSelectNegativeDimModule_basic",
|
"IndexSelectNegativeDimModule_basic",
|
||||||
|
"IndexTensorStaticModule_basic",
|
||||||
|
"IndexTensorMultiIndexStaticModule_basic",
|
||||||
"LayerNormLastDimModule_basic",
|
"LayerNormLastDimModule_basic",
|
||||||
"LayerNormModule_basic",
|
"LayerNormModule_basic",
|
||||||
"LayerNormNormalizeOverAllDimsModule_basic",
|
"LayerNormNormalizeOverAllDimsModule_basic",
|
||||||
|
|
|
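The two newly enabled e2e tests cover `aten.index.Tensor` with a single static index tensor and with multiple broadcast index tensors. A minimal PyTorch sketch of the semantics they exercise (shapes and values here are illustrative, not the suite's exact test modules):

    import torch

    x = torch.arange(1, 10, dtype=torch.float32).reshape(3, 3)

    # Single index tensor (IndexTensorStaticModule-style): each index
    # selects a whole row, so the row dim is replaced by idx's shape.
    idx = torch.tensor([[0, 2], [1, 1]])
    print(x[idx].shape)  # torch.Size([2, 2, 3])

    # Multiple index tensors (IndexTensorMultiIndexStaticModule-style):
    # idx0 and idx1 broadcast against each other, then index element-wise.
    idx0 = torch.tensor([[0, 0, 0], [2, 2, 0]])
    idx1 = torch.tensor([[2], [1]])
    print(x[idx0, idx1].shape)  # torch.Size([2, 3])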
@@ -15,12 +15,13 @@
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "stablehlo/dialect/StablehloOps.h"
+#include "torch-mlir/Conversion/TorchToStablehlo/StablehloLegalizeUtils.h"
 #include "torch-mlir/Conversion/Utils/Utils.h"
 #include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
 #include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
+#include "torch-mlir/Dialect/Torch/IR/TorchTypes.h"
 #include "torch-mlir/Dialect/Torch/Utils/Utils.h"
 #include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h"
-#include "torch-mlir/Conversion/TorchToStablehlo/StablehloLegalizeUtils.h"

 using namespace mlir;
 using namespace mlir::torch;
@@ -375,6 +376,159 @@ LogicalResult ConvertAtenOp<AtenSliceScatterOp>::matchAndRewrite(
   return success();
 }

+// AtenIndexTensorOp
+// Convert AtenIndexTensorOp to StableHlo::GatherOp.
+// Step 1: broadcast the index tensors to the same shape.
+// Step 2: reshape the broadcast indices to add an extra trailing dimension,
+//         then concat them along it.
+// Step 3: create a StableHlo::GatherOp from the input tensor and the indices.
+//
+// Example:
+// Input: [[1, 2, 3],
+//         [4, 5, 6],
+//         [7, 8, 9]]
+// Indices[0]: [[0, 0, 0],
+//              [2, 2, 0]]
+// Indices[1]: [[2],
+//              [1]]
+// Step 1:
+// Indices[0]: [[0, 0, 0],
+//              [2, 2, 0]]
+// Indices[1]: [[2, 2, 2],
+//              [1, 1, 1]]
+// Step 2:
+// Indices: [[[0, 2], [0, 2], [0, 2]],
+//           [[2, 1], [2, 1], [0, 1]]]
+// Step 3:
+// Output: [[3, 3, 3],
+//          [8, 8, 2]]
+template <>
+LogicalResult ConvertAtenOp<AtenIndexTensorOp>::matchAndRewrite(
+    AtenIndexTensorOp op, OpAdaptor adaptor,
+    ConversionPatternRewriter &rewriter) const {
+  Location loc = op->getLoc();
+  Value input = adaptor.getSelf();
+  auto inputTensorType = input.getType().dyn_cast<RankedTensorType>();
+  // Check that the input is a ranked tensor type.
+  if (!inputTensorType)
+    return rewriter.notifyMatchFailure(
+        op, "only ranked tensor inputs are currently supported");
+  Value indexList = op.getIndices();
+  SmallVector<Value> indicesTorchType;
+  if (!getListConstructElements(indexList, indicesTorchType))
+    return op.emitError(
+        "unimplemented: the tensor list is not from list construct");
+
+  auto indexTensors = getTypeConvertedValues(rewriter, loc, getTypeConverter(),
+                                             indicesTorchType);
+
+  // Step 1: broadcast the index tensors.
+  int maxRank = -1;
+  SmallVector<int64_t> indicesShape;
+  SmallVector<int64_t> expandShape;
+  SmallVector<int64_t> concatShape;
+  // Validate each index tensor and compute the maximum rank to broadcast to.
+  for (size_t i = 0; i < indexTensors.size(); i++) {
+    auto indexTensor = indexTensors[i];
+    auto indexTorchTensor = indicesTorchType[i];
+    // TODO: add support for none index input
+    if (indexTorchTensor.getType().isa<Torch::NoneType>())
+      return rewriter.notifyMatchFailure(
+          op, "only ranked tensor indices are supported");
+    auto indexTensorType = indexTensor.getType().cast<RankedTensorType>();
+    for (int64_t size : makeShapeTorchCompatible(indexTensorType.getShape())) {
+      if (size == kUnknownSize)
+        return rewriter.notifyMatchFailure(op, "Dynamic index support TBD");
+    }
+    maxRank = std::max(maxRank, (int)indexTensorType.getRank());
+  }
+
+  RankedTensorType resultType =
+      getTypeConverter()->convertType(op.getType()).cast<RankedTensorType>();
+  SmallVector<int64_t> refinedResultShape =
+      makeShapeTorchCompatible(resultType.getShape());
+  for (int64_t size : refinedResultShape) {
+    if (size == kUnknownSize)
+      return rewriter.notifyMatchFailure(op, "Dynamic index support TBD");
+  }
+  for (int i = 0; i < maxRank; i++) {
+    indicesShape.push_back(refinedResultShape[i]);
+    expandShape.push_back(refinedResultShape[i]);
+    concatShape.push_back(refinedResultShape[i]);
+  }
+  if (indexTensors.size() > 1) {
+    expandShape.push_back(1);
+    concatShape.push_back(indexTensors.size());
+  }
+
+  SmallVector<Value> broadcastedIndices;
+  Type indexElemTy =
+      indexTensors[0].getType().cast<RankedTensorType>().getElementType();
+  RankedTensorType bcastIndexType =
+      RankedTensorType::get(indicesShape, indexElemTy);
+  for (auto indexTensor : indexTensors) {
+    Value bcastVal =
+        hlo::promoteAndBroadcast(rewriter, indexTensor, bcastIndexType);
+    if (indexTensors.size() > 1) {
+      RankedTensorType reshapeType =
+          RankedTensorType::get(expandShape, indexElemTy);
+      bcastVal =
+          rewriter.create<stablehlo::ReshapeOp>(loc, reshapeType, bcastVal);
+    }
+    broadcastedIndices.push_back(bcastVal);
+  }
+
+  // Step 2: concat the index tensors.
+  Value finalIndexTensor = broadcastedIndices[0];
+  if (broadcastedIndices.size() > 1) {
+    RankedTensorType concatTy = RankedTensorType::get(concatShape, indexElemTy);
+    finalIndexTensor = rewriter.create<stablehlo::ConcatenateOp>(
+        loc, concatTy, ValueRange(broadcastedIndices), concatShape.size() - 1);
+  }
+
+  // Step 3: create the stablehlo::GatherOp.
+  RankedTensorType finalIndexTy =
+      finalIndexTensor.getType().cast<RankedTensorType>();
+  int64_t indicesRank = finalIndexTy.getRank();
+  int64_t numIndicesDim = broadcastedIndices.size();
+  int64_t indexVecDim = numIndicesDim > 1 ? indicesRank - 1 : indicesRank;
+
+  SmallVector<int64_t> offsetDims;
+  SmallVector<int64_t> collapsedDims;
+  SmallVector<int64_t> startIndexMap;
+  for (int64_t i = 0; i < numIndicesDim; ++i) {
+    collapsedDims.push_back(i);
+    startIndexMap.push_back(i);
+  }
+  for (int64_t i = numIndicesDim; i < inputTensorType.getRank(); i++) {
+    if (numIndicesDim > 1) {
+      offsetDims.push_back(i + indicesRank - 1 - numIndicesDim);
+    } else {
+      offsetDims.push_back(i + indicesRank - numIndicesDim);
+    }
+  }
+  auto dimsAttr = stablehlo::GatherDimensionNumbersAttr::get(
+      rewriter.getContext(),
+      /*offsetDims=*/offsetDims,
+      /*collapsedSliceDims=*/collapsedDims,
+      /*startIndexMap=*/startIndexMap,
+      /*indexVecDim=*/indexVecDim);
+
+  SmallVector<int64_t> sliceSizes;
+  auto inputShape = makeShapeTorchCompatible(inputTensorType.getShape());
+  for (int64_t i = 0; i < inputTensorType.getRank(); ++i) {
+    if (i < numIndicesDim) {
+      sliceSizes.push_back(1);
+    } else {
+      sliceSizes.push_back(inputShape[i]);
+    }
+  }
+
+  rewriter.replaceOpWithNewOp<stablehlo::GatherOp>(
+      op, resultType, input, finalIndexTensor, dimsAttr,
+      rewriter.getI64TensorAttr(sliceSizes));
+  return success();
+}
+
 void mlir::torch::torch_to_stablehlo::
     populateGatherScatterOpPatternsAndLegality(
         TypeConverter &typeConverter, RewritePatternSet &patterns,
@@ -388,5 +542,6 @@ void mlir::torch::torch_to_stablehlo::
   INSERT_ATENOP_PATTERN(AtenIndexSelectOp);
   INSERT_ATENOP_PATTERN(AtenGatherOp);
   INSERT_ATENOP_PATTERN(AtenSliceScatterOp);
+  INSERT_ATENOP_PATTERN(AtenIndexTensorOp);
 #undef INSERT_ATENOP_PATTERN
 }
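As a sanity check, the worked example from the comment block above can be reproduced with PyTorch's advanced indexing; a small sketch, independent of the lowering itself:

    import torch

    inp = torch.tensor([[1, 2, 3],
                        [4, 5, 6],
                        [7, 8, 9]])
    idx0 = torch.tensor([[0, 0, 0],
                         [2, 2, 0]])
    idx1 = torch.tensor([[2],
                         [1]])  # Step 1 broadcasts this to [[2, 2, 2], [1, 1, 1]]

    # aten.index.Tensor pairs the broadcast indices element-wise, i.e. the
    # (row, col) pairs produced by Step 2 of the lowering.
    print(inp[idx0, idx1])
    # tensor([[3, 3, 3],
    #         [8, 8, 2]])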
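The gather configuration itself (collapse the first numIndicesDim input dims with slice size 1; keep the remaining dims whole as offset dims) can be modeled in plain Python. This hypothetical gather_like helper illustrates the semantics of the emitted stablehlo.gather and is not part of the patch:

    import torch

    def gather_like(inp: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
        # indices has shape [..., numIndicesDim]; each trailing index vector
        # picks one position along the first numIndicesDim dims of inp
        # (slice size 1, collapsed), keeping the remaining dims whole.
        num_indices_dim = indices.shape[-1]
        batch_shape = indices.shape[:-1]
        slices = [inp[tuple(vec)]
                  for vec in indices.reshape(-1, num_indices_dim).tolist()]
        return torch.stack(slices).reshape(
            *batch_shape, *inp.shape[num_indices_dim:])

    inp = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
    indices = torch.tensor([[[0, 2], [0, 2], [0, 2]],
                            [[2, 1], [2, 1], [0, 1]]])  # Step 2 result
    print(gather_like(inp, indices))
    # tensor([[3, 3, 3],
    #         [8, 8, 2]])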