[Stablehlo] Add `AtenIndexTensor` StableHlo support (#2107)

* Add AtenIndexTensor StableHlo support

* clean up

* Empty commit, trigger test

* try to debug hanging test

* fix segfulat

* fix bad include

---------

Co-authored-by: zhekun.zhang <zhekun.zhang@bytedance.com>
pull/2154/merge
Zhekun Zhang 2023-05-24 11:13:57 -07:00 committed by GitHub
parent a426363b7d
commit eb8f56aeb7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 158 additions and 1 deletions

View File

@ -456,6 +456,8 @@ STABLEHLO_PASS_SET = {
"IndexSelectWholeDimensionModule_basic", "IndexSelectWholeDimensionModule_basic",
"IndexSelectWholeTensorModule_basic", "IndexSelectWholeTensorModule_basic",
"IndexSelectNegativeDimModule_basic", "IndexSelectNegativeDimModule_basic",
"IndexTensorStaticModule_basic",
"IndexTensorMultiIndexStaticModule_basic",
"LayerNormLastDimModule_basic", "LayerNormLastDimModule_basic",
"LayerNormModule_basic", "LayerNormModule_basic",
"LayerNormNormalizeOverAllDimsModule_basic", "LayerNormNormalizeOverAllDimsModule_basic",

View File

@ -15,12 +15,13 @@
#include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "stablehlo/dialect/StablehloOps.h" #include "stablehlo/dialect/StablehloOps.h"
#include "torch-mlir/Conversion/TorchToStablehlo/StablehloLegalizeUtils.h"
#include "torch-mlir/Conversion/Utils/Utils.h" #include "torch-mlir/Conversion/Utils/Utils.h"
#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h" #include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h" #include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
#include "torch-mlir/Dialect/Torch/IR/TorchTypes.h"
#include "torch-mlir/Dialect/Torch/Utils/Utils.h" #include "torch-mlir/Dialect/Torch/Utils/Utils.h"
#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h" #include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h"
#include "torch-mlir/Conversion/TorchToStablehlo/StablehloLegalizeUtils.h"
using namespace mlir; using namespace mlir;
using namespace mlir::torch; using namespace mlir::torch;
@ -375,6 +376,159 @@ LogicalResult ConvertAtenOp<AtenSliceScatterOp>::matchAndRewrite(
return success(); return success();
} }
// AtenIndexTensorOp
// Convert AtenIndexTensorOp to StableHlo::GatherOp
// Step 1: broadcast indices to the same shape
// Step 2: reshape broadcasted indices to have extra last dimension and concat
// Step 3: Create StableHlo::GatherOp with input tensor and indices
//
// Example:
// Input: [[1, 2, 3],
// [4, 5, 6],
// [7, 8, 9]]
// Indices[0]: [[0, 0, 0],
// [2, 2, 0]]
// Indices[1]: [[2],
// [1]]
// Step 1:
// Indices[0]: [[0, 0, 0],
// [2, 2, 0]]
// Indices[1]: [[2, 2, 2],
// [1, 1, 1]]
// Step 2:
// Indices: [[[0, 2], [0, 2], [0, 2]],
// [[2, 1], [2, 1], [0, 1]]]
// Step 3:
// Output: [[3, 3, 3],
// [8, 8, 2]]
template <>
LogicalResult ConvertAtenOp<AtenIndexTensorOp>::matchAndRewrite(
AtenIndexTensorOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
Location loc = op->getLoc();
Value input = adaptor.getSelf();
auto inputTensorType = input.getType().dyn_cast<RankedTensorType>();
// Check input is a tensor type.
if (!inputTensorType)
return rewriter.notifyMatchFailure(
op, "Only tensor types input are currently supported");
Value indexList = op.getIndices();
SmallVector<Value> indicesTorchType;
if (!getListConstructElements(indexList, indicesTorchType))
return op.emitError(
"unimplemented: the tensor list is not from list construct");
auto indexTensors = getTypeConvertedValues(rewriter, loc, getTypeConverter(),
indicesTorchType);
// Step 1: broadcast indices tensors
int maxRank = -1;
SmallVector<int64_t> indicesShape;
SmallVector<int64_t> expandShape;
SmallVector<int64_t> concatShape;
// concat index tensor into to indices tensor for concat
for (size_t i = 0; i < indexTensors.size(); i++) {
auto indexTensor = indexTensors[i];
auto indexTorchTensor = indicesTorchType[i];
// TODO: add support for none index input
if (indexTorchTensor.getType().isa<Torch::NoneType>())
return rewriter.notifyMatchFailure(
op, "Only list ranked tensor types index are supported");
auto indexTensorType = indexTensor.getType().cast<RankedTensorType>();
for (int64_t size : makeShapeTorchCompatible(indexTensorType.getShape())) {
if (size == kUnknownSize)
return rewriter.notifyMatchFailure(op, "Dynamic index support TBD");
}
maxRank = std::max(maxRank, (int)indexTensorType.getRank());
}
RankedTensorType resultType =
getTypeConverter()->convertType(op.getType()).cast<RankedTensorType>();
SmallVector<int64_t> refinedResultShape =
makeShapeTorchCompatible(resultType.getShape());
for (int64_t size : refinedResultShape) {
if (size == kUnknownSize)
return rewriter.notifyMatchFailure(op, "Dynamic index support TBD");
}
for (int i = 0; i < maxRank; i++) {
indicesShape.push_back(refinedResultShape[i]);
expandShape.push_back(refinedResultShape[i]);
concatShape.push_back(refinedResultShape[i]);
}
if (indexTensors.size() > 1) {
expandShape.push_back(1);
concatShape.push_back(indexTensors.size());
}
SmallVector<Value> broadcastedIndices;
Type indexElemTy =
indexTensors[0].getType().cast<RankedTensorType>().getElementType();
RankedTensorType bcastIndexType =
RankedTensorType::get(indicesShape, indexElemTy);
for (auto indexTensor : indexTensors) {
Value bcastVal =
hlo::promoteAndBroadcast(rewriter, indexTensor, bcastIndexType);
if (indexTensors.size() > 1) {
RankedTensorType reshapeType =
RankedTensorType::get(expandShape, indexElemTy);
bcastVal =
rewriter.create<stablehlo::ReshapeOp>(loc, reshapeType, bcastVal);
}
broadcastedIndices.push_back(bcastVal);
}
// Step 2: concat index tensors
Value finalIndexTensor = broadcastedIndices[0];
if (broadcastedIndices.size() > 1) {
RankedTensorType concatTy = RankedTensorType::get(concatShape, indexElemTy);
finalIndexTensor = rewriter.create<stablehlo::ConcatenateOp>(
loc, concatTy, ValueRange(broadcastedIndices), concatShape.size() - 1);
}
// Step 3: create stablehlo::GatherOp
RankedTensorType finalIndexTy =
finalIndexTensor.getType().cast<RankedTensorType>();
int64_t indicesRank = finalIndexTy.getRank();
int64_t numIndicesDim = broadcastedIndices.size();
int64_t indexVecDim = numIndicesDim > 1 ? indicesRank - 1 : indicesRank;
SmallVector<int64_t> offsetDims;
SmallVector<int64_t> collapsedDims;
SmallVector<int64_t> startIndexMap;
for (int64_t i = 0; i < numIndicesDim; ++i) {
collapsedDims.push_back(i);
startIndexMap.push_back(i);
}
for (int64_t i = numIndicesDim; i < inputTensorType.getRank(); i++) {
if (numIndicesDim > 1) {
offsetDims.push_back(i + indicesRank - 1 - numIndicesDim);
} else {
offsetDims.push_back(i + indicesRank - numIndicesDim);
}
}
auto dimsAttr = stablehlo::GatherDimensionNumbersAttr::get(
rewriter.getContext(),
/*offsetDims=*/offsetDims,
/*collapsedSliceDims=*/collapsedDims,
/*startIndexMap=*/startIndexMap,
/*indexVecDim=*/indexVecDim);
SmallVector<int64_t> sliceSizes;
auto inputShape = makeShapeTorchCompatible(inputTensorType.getShape());
for (int64_t i = 0; i < inputTensorType.getRank(); ++i) {
if (i < numIndicesDim) {
sliceSizes.push_back(1);
} else {
sliceSizes.push_back(inputShape[i]);
}
}
rewriter.replaceOpWithNewOp<stablehlo::GatherOp>(
op, resultType, input, finalIndexTensor, dimsAttr,
rewriter.getI64TensorAttr(sliceSizes));
return success();
}
void mlir::torch::torch_to_stablehlo:: void mlir::torch::torch_to_stablehlo::
populateGatherScatterOpPatternsAndLegality( populateGatherScatterOpPatternsAndLegality(
TypeConverter &typeConverter, RewritePatternSet &patterns, TypeConverter &typeConverter, RewritePatternSet &patterns,
@ -388,5 +542,6 @@ void mlir::torch::torch_to_stablehlo::
INSERT_ATENOP_PATTERN(AtenIndexSelectOp); INSERT_ATENOP_PATTERN(AtenIndexSelectOp);
INSERT_ATENOP_PATTERN(AtenGatherOp); INSERT_ATENOP_PATTERN(AtenGatherOp);
INSERT_ATENOP_PATTERN(AtenSliceScatterOp); INSERT_ATENOP_PATTERN(AtenSliceScatterOp);
INSERT_ATENOP_PATTERN(AtenIndexTensorOp);
#undef INSERT_ATENOP_PATTERN #undef INSERT_ATENOP_PATTERN
} }