From eb8f56aeb790df1091d6844c0f5dafa1fc1b6051 Mon Sep 17 00:00:00 2001
From: Zhekun Zhang <32320144+zhekunz2@users.noreply.github.com>
Date: Wed, 24 May 2023 11:13:57 -0700
Subject: [PATCH] [Stablehlo] Add `AtenIndexTensor` StableHlo support (#2107)

* Add AtenIndexTensor StableHlo support

* clean up

* Empty commit, trigger test

* try to debug hanging test

* fix segfault

* fix bad include

---------

Co-authored-by: zhekun.zhang
---
 e2e_testing/xfail_sets.py                     |   2 +
 .../TorchToStablehlo/GatherScatter.cpp        | 157 +++++++++++++++++-
 2 files changed, 158 insertions(+), 1 deletion(-)

diff --git a/e2e_testing/xfail_sets.py b/e2e_testing/xfail_sets.py
index 0155ac529..7f6b45077 100644
--- a/e2e_testing/xfail_sets.py
+++ b/e2e_testing/xfail_sets.py
@@ -456,6 +456,8 @@ STABLEHLO_PASS_SET = {
     "IndexSelectWholeDimensionModule_basic",
     "IndexSelectWholeTensorModule_basic",
     "IndexSelectNegativeDimModule_basic",
+    "IndexTensorStaticModule_basic",
+    "IndexTensorMultiIndexStaticModule_basic",
     "LayerNormLastDimModule_basic",
     "LayerNormModule_basic",
     "LayerNormNormalizeOverAllDimsModule_basic",
diff --git a/lib/Conversion/TorchToStablehlo/GatherScatter.cpp b/lib/Conversion/TorchToStablehlo/GatherScatter.cpp
index 0118a8a59..c2dc9561f 100644
--- a/lib/Conversion/TorchToStablehlo/GatherScatter.cpp
+++ b/lib/Conversion/TorchToStablehlo/GatherScatter.cpp
@@ -15,12 +15,13 @@
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "stablehlo/dialect/StablehloOps.h"
+#include "torch-mlir/Conversion/TorchToStablehlo/StablehloLegalizeUtils.h"
 #include "torch-mlir/Conversion/Utils/Utils.h"
 #include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
 #include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
+#include "torch-mlir/Dialect/Torch/IR/TorchTypes.h"
 #include "torch-mlir/Dialect/Torch/Utils/Utils.h"
 #include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h"
-#include "torch-mlir/Conversion/TorchToStablehlo/StablehloLegalizeUtils.h"
 
 using namespace mlir;
 using
namespace mlir::torch;
@@ -375,6 +376,159 @@ LogicalResult ConvertAtenOp::matchAndRewrite(
   return success();
 }
 
+// AtenIndexTensorOp
+// Convert AtenIndexTensorOp to StableHlo::GatherOp
+// Step 1: broadcast indices to the same shape
+// Step 2: reshape broadcasted indices to have extra last dimension and concat
+// Step 3: Create StableHlo::GatherOp with input tensor and indices
+//
+// Example:
+// Input: [[1, 2, 3],
+//         [4, 5, 6],
+//         [7, 8, 9]]
+// Indices[0]: [[0, 0, 0],
+//              [2, 2, 0]]
+// Indices[1]: [[2],
+//              [1]]
+// Step 1:
+// Indices[0]: [[0, 0, 0],
+//              [2, 2, 0]]
+// Indices[1]: [[2, 2, 2],
+//              [1, 1, 1]]
+// Step 2:
+// Indices: [[[0, 2], [0, 2], [0, 2]],
+//           [[2, 1], [2, 1], [0, 1]]]
+// Step 3:
+// Output: [[3, 3, 3],
+//          [8, 8, 2]]
+template <>
+LogicalResult ConvertAtenOp<AtenIndexTensorOp>::matchAndRewrite(
+    AtenIndexTensorOp op, OpAdaptor adaptor,
+    ConversionPatternRewriter &rewriter) const {
+  Location loc = op->getLoc();
+  Value input = adaptor.getSelf();
+  auto inputTensorType = input.getType().dyn_cast<RankedTensorType>();
+  // Check input is a tensor type.
+  if (!inputTensorType)
+    return rewriter.notifyMatchFailure(
+        op, "Only tensor types input are currently supported");
+  Value indexList = op.getIndices();
+  SmallVector<Value> indicesTorchType;
+  if (!getListConstructElements(indexList, indicesTorchType))
+    return op.emitError(
+        "unimplemented: the tensor list is not from list construct");
+
+  auto indexTensors = getTypeConvertedValues(rewriter, loc, getTypeConverter(),
+                                             indicesTorchType);
+
+  // Step 1: broadcast indices tensors
+  int maxRank = -1;
+  SmallVector<int64_t> indicesShape;
+  SmallVector<int64_t> expandShape;
+  SmallVector<int64_t> concatShape;
+  // concat index tensor into to indices tensor for concat
+  for (size_t i = 0; i < indexTensors.size(); i++) {
+    auto indexTensor = indexTensors[i];
+    auto indexTorchTensor = indicesTorchType[i];
+    // TODO: add support for none index input
+    if (indexTorchTensor.getType().isa<Torch::NoneType>())
+      return rewriter.notifyMatchFailure(
+          op, "Only list ranked tensor types index are supported");
+    auto indexTensorType = indexTensor.getType().cast<RankedTensorType>();
+    for (int64_t size : makeShapeTorchCompatible(indexTensorType.getShape())) {
+      if (size == kUnknownSize)
+        return rewriter.notifyMatchFailure(op, "Dynamic index support TBD");
+    }
+    maxRank = std::max(maxRank, (int)indexTensorType.getRank());
+  }
+
+  RankedTensorType resultType =
+      getTypeConverter()->convertType(op.getType()).cast<RankedTensorType>();
+  SmallVector<int64_t> refinedResultShape =
+      makeShapeTorchCompatible(resultType.getShape());
+  for (int64_t size : refinedResultShape) {
+    if (size == kUnknownSize)
+      return rewriter.notifyMatchFailure(op, "Dynamic index support TBD");
+  }
+  for (int i = 0; i < maxRank; i++) {
+    indicesShape.push_back(refinedResultShape[i]);
+    expandShape.push_back(refinedResultShape[i]);
+    concatShape.push_back(refinedResultShape[i]);
+  }
+  if (indexTensors.size() > 1) {
+    expandShape.push_back(1);
+    concatShape.push_back(indexTensors.size());
+  }
+
+  SmallVector<Value> broadcastedIndices;
+  Type indexElemTy =
+      indexTensors[0].getType().cast<RankedTensorType>().getElementType();
+  RankedTensorType bcastIndexType =
+      RankedTensorType::get(indicesShape, indexElemTy);
+  for (auto indexTensor : indexTensors) {
+    Value bcastVal =
+        hlo::promoteAndBroadcast(rewriter, indexTensor, bcastIndexType);
+    if (indexTensors.size() > 1) {
+      RankedTensorType reshapeType =
+          RankedTensorType::get(expandShape, indexElemTy);
+      bcastVal =
+          rewriter.create<stablehlo::ReshapeOp>(loc, reshapeType, bcastVal);
+    }
+    broadcastedIndices.push_back(bcastVal);
+  }
+
+  // Step 2: concat index tensors
+  Value finalIndexTensor = broadcastedIndices[0];
+  if (broadcastedIndices.size() > 1) {
+    RankedTensorType concatTy = RankedTensorType::get(concatShape, indexElemTy);
+    finalIndexTensor = rewriter.create<stablehlo::ConcatenateOp>(
+        loc, concatTy, ValueRange(broadcastedIndices), concatShape.size() - 1);
+  }
+
+  // Step 3: create stablehlo::GatherOp
+  RankedTensorType finalIndexTy =
+      finalIndexTensor.getType().cast<RankedTensorType>();
+  int64_t indicesRank = finalIndexTy.getRank();
+  int64_t numIndicesDim = broadcastedIndices.size();
+  int64_t indexVecDim = numIndicesDim > 1 ? indicesRank - 1 : indicesRank;
+
+  SmallVector<int64_t> offsetDims;
+  SmallVector<int64_t> collapsedDims;
+  SmallVector<int64_t> startIndexMap;
+  for (int64_t i = 0; i < numIndicesDim; ++i) {
+    collapsedDims.push_back(i);
+    startIndexMap.push_back(i);
+  }
+  for (int64_t i = numIndicesDim; i < inputTensorType.getRank(); i++) {
+    if (numIndicesDim > 1) {
+      offsetDims.push_back(i + indicesRank - 1 - numIndicesDim);
+    } else {
+      offsetDims.push_back(i + indicesRank - numIndicesDim);
+    }
+  }
+  auto dimsAttr = stablehlo::GatherDimensionNumbersAttr::get(
+      rewriter.getContext(),
+      /*offsetDims=*/offsetDims,
+      /*collapsedSliceDims=*/collapsedDims,
+      /*startIndexMap=*/startIndexMap,
+      /*indexVecDim=*/indexVecDim);
+
+  SmallVector<int64_t> sliceSizes;
+  auto inputShape = makeShapeTorchCompatible(inputTensorType.getShape());
+  for (int64_t i = 0; i < inputTensorType.getRank(); ++i) {
+    if (i < numIndicesDim) {
+      sliceSizes.push_back(1);
+    } else {
+      sliceSizes.push_back(inputShape[i]);
+    }
+  }
+
+  rewriter.replaceOpWithNewOp<stablehlo::GatherOp>(
+      op, resultType, input, finalIndexTensor, dimsAttr,
+      rewriter.getI64TensorAttr(sliceSizes));
+  return success();
+}
+
 void mlir::torch::torch_to_stablehlo::
     populateGatherScatterOpPatternsAndLegality(
         TypeConverter &typeConverter, RewritePatternSet &patterns,
@@ -388,5 +542,6 @@ void mlir::torch::torch_to_stablehlo::
   INSERT_ATENOP_PATTERN(AtenIndexSelectOp);
   INSERT_ATENOP_PATTERN(AtenGatherOp);
   INSERT_ATENOP_PATTERN(AtenSliceScatterOp);
+  INSERT_ATENOP_PATTERN(AtenIndexTensorOp);
 #undef INSERT_ATENOP_PATTERN
 }