//===----------------------------------------------------------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // Also available under a BSD-style license. See LICENSE. // //===----------------------------------------------------------------------===// #include "torch-mlir/Conversion/TorchToStablehlo/TorchToStablehlo.h" #include "../PassDetail.h" #include "PopulatePatterns.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "stablehlo/dialect/StablehloOps.h" #include "torch-mlir/Conversion/TorchToStablehlo/StablehloLegalizeUtils.h" #include "torch-mlir/Conversion/Utils/Utils.h" #include "torch-mlir/Dialect/Torch/IR/TorchDialect.h" #include "torch-mlir/Dialect/Torch/IR/TorchOps.h" #include "torch-mlir/Dialect/Torch/IR/TorchTypes.h" #include "torch-mlir/Dialect/Torch/Utils/Utils.h" #include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h" using namespace mlir; using namespace mlir::torch; using namespace mlir::torch::Torch; using namespace mlir::torch::torch_to_stablehlo; namespace { Value gatherTensorAlongSingleAxis(PatternRewriter &rewriter, Operation *op, Value input, Value indices, int64_t axis, size_t dimSizeIndexBits) { auto loc = op->getLoc(); Type intType = rewriter.getIntegerType(dimSizeIndexBits); Value one = rewriter.create( loc, rewriter.getIntegerAttr(intType, 1)); // sliceSizes auto inputRankTy = input.getType().dyn_cast(); auto inputRank = inputRankTy.getRank(); SmallVector sliceSizes; sliceSizes.reserve(inputRank); for (int64_t r = 0; r < inputRank; ++r) { if (r == axis) { sliceSizes.push_back(one); } else { sliceSizes.push_back(rewriter.create( loc, intType, rewriter.create(loc, input, r))); } } auto sliceSizesTensor = rewriter.create(loc, sliceSizes); // offsetDims SmallVector offsetDims; offsetDims.reserve(inputRank); for (int64_t r = 0; r < axis; ++r) { offsetDims.push_back(r); } auto indicesRankTy = indices.getType().dyn_cast(); auto indicesRank = indicesRankTy.getRank(); for (int64_t r = axis + 1; r < inputRank; ++r) { offsetDims.push_back(r + indicesRank - 1); } // collapsedSliceDims SmallVector collapsedSliceDims(1, axis); // startIndexMap SmallVector startIndexMap(1, axis); // indexVecDim int64_t indexVecDim = indicesRank; auto dimsAttr = stablehlo::GatherDimensionNumbersAttr::get( rewriter.getContext(), /*offsetDims=*/offsetDims, /*collapsedSliceDims=*/collapsedSliceDims, /*startIndexMap=*/startIndexMap, /*indexVecDim=*/indexVecDim); // outputShape = input.shape[:axis] + indices.shape + // input.shape[axis + 1:] auto inputShape = inputRankTy.getShape(); auto indicesShape = indicesRankTy.getShape(); SmallVector outputShape(inputShape.begin(), inputShape.begin() + axis); outputShape.insert(outputShape.end(), indicesShape.begin(), indicesShape.end()); outputShape.insert(outputShape.end(), inputShape.begin() + axis + 1, inputShape.end()); // create output tensor type auto outputTy = RankedTensorType::get(outputShape, inputRankTy.getElementType()); return rewriter .create(loc, outputTy, input, indices, sliceSizesTensor, dimsAttr) .getResult(); } template LogicalResult prepareArgumentsForSlicingOp(OpTy op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter, SmallVector &resultShape, SmallVector &offsets, SmallVector &strides) { Location loc = op.getLoc(); auto input = adaptor.getSelf(); RankedTensorType inputType = input.getType().template cast(); Value zero = rewriter.create(loc, 0); Value one = rewriter.create(loc, 1); int64_t dim; if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim))) return op->emitError("unimplemented: dim is not constant"); int64_t inputRank = inputType.getRank(); dim = toPositiveDim(dim, inputRank); if (!isValidDim(dim, inputRank)) return rewriter.notifyMatchFailure(op, "dim is statically invalid"); SmallVector inputShape = getTensorSizes(rewriter, loc, input); Value dimSize = inputShape[dim]; Value torchTypeStart = op.getStart(); Value torchTypeEnd = op.getEnd(); Value builtinTypeStart = adaptor.getStart(); Value builtinTypeEnd = adaptor.getEnd(); if (torchTypeStart.getType().isa() || torchTypeEnd.getType().isa()) return rewriter.notifyMatchFailure(op, "unimplemented optional type arg"); int64_t step; if (!matchPattern(op.getStep(), m_TorchConstantInt(&step))) { if (!op.getStep().getType().template isa()) return op->emitError("unimplemented: step is not constant"); step = 1; } Value start = toPositiveValidDim(rewriter, loc, torchTypeStart, builtinTypeStart, zero, dimSize); Value end = toPositiveValidDim(rewriter, loc, torchTypeEnd, builtinTypeEnd, dimSize, dimSize); // end >= start ? end : start Value endSgeStart = rewriter.create( loc, arith::CmpIPredicate::sge, end, start); end = rewriter.create(loc, endSgeStart, end, start); Value stepIndex = rewriter.create(loc, step); // Slice logic: resultSize = floordiv(end - start + step - 1, step) resultShape = getTensorSizes(rewriter, loc, input); Value len = rewriter.create(loc, end, start); Value resultSize = rewriter.create(loc, len, stepIndex); resultSize = rewriter.create(loc, resultSize, one); resultSize = rewriter.create(loc, resultSize, stepIndex); resultShape[dim] = resultSize; strides.resize(inputType.getRank(), one); offsets.resize(inputType.getRank(), zero); offsets[dim] = start; strides[dim] = rewriter.create(loc, strides[dim], stepIndex); return success(); } } // namespace // Ref: // https://pytorch.org/docs/stable/generated/torch.nn.functional.embedding.html // padding_idx (int, optional) // – If specified, the entries at padding_idx do not contribute to the // gradient; therefore, the embedding vector at padding_idx is not updated // during training, i.e. it remains as a fixed “pad”. // scale_grad_by_freq (boolean, optional) // – If given, this will scale gradients by the inverse of frequency of the // words in the mini-batch. Default False. // sparse (bool, optional) // – If True, gradient w.r.t. weight matrix will be a sparse tensor. template <> LogicalResult ConvertAtenOp::matchAndRewrite( AtenEmbeddingOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { auto weight = adaptor.getWeight(); auto weightTy = weight.getType().cast(); if (!weightTy) return op.emitError("only ranked tensor types are supported"); int64_t padding_idx; if (!matchPattern(op.getPaddingIdx(), m_TorchConstantInt(&padding_idx))) return rewriter.notifyMatchFailure( op, "only constant padding_idx is currently supported"); bool scale_grad_by_freq; if (!matchPattern(op.getScaleGradByFreq(), m_TorchConstantBool(&scale_grad_by_freq))) return rewriter.notifyMatchFailure( op, "only constant scale_grad_by_freq is currently supported"); if (scale_grad_by_freq) return rewriter.notifyMatchFailure( op, "scale gradients is currently not supported"); bool sparse; if (!matchPattern(op.getSparse(), m_TorchConstantBool(&sparse))) return rewriter.notifyMatchFailure( op, "only constant sparse is currently supported"); if (sparse) return rewriter.notifyMatchFailure( op, "sparse gradients is currently not supported"); Value output = gatherTensorAlongSingleAxis( rewriter, op, weight, adaptor.getIndices(), 0, options.dimSizeIndexBits); rewriter.replaceOpWithNewOp( op, getTypeConverter()->convertType(op.getType()), output); return success(); } template <> LogicalResult ConvertAtenOp::matchAndRewrite( AtenIndexSelectOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { auto self = adaptor.getSelf(); auto selfTy = self.getType().cast(); if (!selfTy) return op.emitError("only ranked tensor types are supported"); int64_t dim; if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim))) return rewriter.notifyMatchFailure( op, "only constant dim is currently supported"); int64_t inputRank = selfTy.getRank(); dim = toPositiveDim(dim, inputRank); if (!isValidDim(dim, inputRank)) return rewriter.notifyMatchFailure(op, "dim is statically invalid"); Value output = gatherTensorAlongSingleAxis( rewriter, op, self, adaptor.getIndex(), dim, options.dimSizeIndexBits); rewriter.replaceOpWithNewOp( op, getTypeConverter()->convertType(op.getType()), output); return success(); } // AtenGatherOp template <> LogicalResult ConvertAtenOp::matchAndRewrite( AtenGatherOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { Location loc = op->getLoc(); Value input = adaptor.getSelf(); Value index = adaptor.getIndex(); auto inputType = input.getType().cast(); auto indexType = index.getType().cast(); auto indexElemType = indexType.getElementType(); if (indexType.getRank() != inputType.getRank()) { return op.emitError("`index` and `input` param should have the same rank"); } int64_t dim; if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim))) { return rewriter.notifyMatchFailure( op, "only constant int `dim` param supported"); } dim = toPositiveDim(dim, inputType.getRank()); if (!isValidDim(dim, inputType.getRank())) { return rewriter.notifyMatchFailure(op, "invalid `dim` param detected"); } bool sparseGrad = false; if (!matchPattern(op.getSparseGrad(), m_TorchConstantBool(&sparseGrad))) { return rewriter.notifyMatchFailure( op, "only constant boolean `sparse_grad` param supported"); } auto options = getOptions(); auto indexShapeInfo = hlo::getDimSizesOfTensor(rewriter, op, index, options.dimSizeIndexBits); if (failed(indexShapeInfo)) { return rewriter.notifyMatchFailure( op, "failed to get dim sizes of `index` param"); } auto intType = rewriter.getIntegerType(options.dimSizeIndexBits); auto one = rewriter.create( loc, rewriter.getIntegerAttr(intType, 1)); auto toConcatIndexShapeValueVec = *indexShapeInfo; toConcatIndexShapeValueVec.push_back(one); auto toConcatIndexShape = rewriter.create(loc, toConcatIndexShapeValueVec); auto indexShape = indexType.getShape(); SmallVector toConcatIndexShapeVec(indexShape.begin(), indexShape.end()); toConcatIndexShapeVec.push_back(1); RankedTensorType toConcatIndexType = RankedTensorType::get(toConcatIndexShapeVec, indexElemType); SmallVector toConcat; for (int64_t i = 0; i < inputType.getRank(); ++i) { if (i == dim) { toConcat.push_back(rewriter.create( loc, toConcatIndexType, index, toConcatIndexShape)); } else { toConcat.push_back(rewriter.create( loc, toConcatIndexType, toConcatIndexShape, rewriter.getI64IntegerAttr(i))); } } auto gatherIndicies = rewriter.create( loc, toConcat, static_cast(inputType.getRank())); SmallVector sliceSizes(inputType.getRank(), 1); int64_t indexVecDim = inputType.getRank(); SmallVector collapsedDims; SmallVector startIndexMap; for (int64_t i = 0; i < inputType.getRank(); ++i) { collapsedDims.push_back(i); startIndexMap.push_back(i); } auto dimsAttr = stablehlo::GatherDimensionNumbersAttr::get( rewriter.getContext(), /*offsetDims=*/{}, /*collapsedSliceDims=*/collapsedDims, /*startIndexMap=*/startIndexMap, /*indexVecDim=*/indexVecDim); rewriter.replaceOpWithNewOp( op, input, gatherIndicies, dimsAttr, rewriter.getI64TensorAttr(sliceSizes)); return success(); } // AtenSliceScatterOp template <> LogicalResult ConvertAtenOp::matchAndRewrite( AtenSliceScatterOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { if (failed(verifyLinalgCompatibleTypes(op, rewriter))) return failure(); Location loc = op.getLoc(); TypeConverter *typeConverter = getTypeConverter(); auto input = adaptor.getSelf(); RankedTensorType resultType = typeConverter->convertType(op->getResult(0).getType()) .cast(); SmallVector resultShape; SmallVector offsets; SmallVector strides; if (failed(prepareArgumentsForSlicingOp( op, adaptor, rewriter, resultShape, offsets, strides))) { return failure(); } Value src = adaptor.getSrc(); auto srcType = src.getType().cast(); int64_t srcRank = srcType.getRank(); SmallVector srcAbstractSizes(srcRank, kUnknownSize); auto abstractSrcType = RankedTensorType::get( makeShapeLLVMCompatible(srcAbstractSizes), srcType.getElementType()); Value abstractSrc = rewriter.create(loc, abstractSrcType, src); Value result = rewriter.create( loc, abstractSrc, input, offsets, resultShape, strides); rewriter.replaceOpWithNewOp(op, resultType, result); return success(); } // AtenIndexTensorOp // Convert AtenIndexTensorOp to StableHlo::GatherOp // Step 1: broadcast indices to the same shape // Step 2: reshape broadcasted indices to have extra last dimension and concat // Step 3: Create StableHlo::GatherOp with input tensor and indices // // Example: // Input: [[1, 2, 3], // [4, 5, 6], // [7, 8, 9]] // Indices[0]: [[0, 0, 0], // [2, 2, 0]] // Indices[1]: [[2], // [1]] // Step 1: // Indices[0]: [[0, 0, 0], // [2, 2, 0]] // Indices[1]: [[2, 2, 2], // [1, 1, 1]] // Step 2: // Indices: [[[0, 2], [0, 2], [0, 2]], // [[2, 1], [2, 1], [0, 1]]] // Step 3: // Output: [[3, 3, 3], // [8, 8, 2]] template <> LogicalResult ConvertAtenOp::matchAndRewrite( AtenIndexTensorOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { Location loc = op->getLoc(); Value input = adaptor.getSelf(); auto inputTensorType = input.getType().dyn_cast(); // Check input is a tensor type. if (!inputTensorType) return rewriter.notifyMatchFailure( op, "Only tensor types input are currently supported"); Value indexList = op.getIndices(); SmallVector indicesTorchType; if (!getListConstructElements(indexList, indicesTorchType)) return op.emitError( "unimplemented: the tensor list is not from list construct"); auto indexTensors = getTypeConvertedValues(rewriter, loc, getTypeConverter(), indicesTorchType); // Step 1: broadcast indices tensors int maxRank = -1; SmallVector indicesShape; SmallVector expandShape; SmallVector concatShape; // concat index tensor into to indices tensor for concat for (size_t i = 0; i < indexTensors.size(); i++) { auto indexTensor = indexTensors[i]; auto indexTorchTensor = indicesTorchType[i]; // TODO: add support for none index input if (indexTorchTensor.getType().isa()) return rewriter.notifyMatchFailure( op, "Only list ranked tensor types index are supported"); auto indexTensorType = indexTensor.getType().cast(); for (int64_t size : makeShapeTorchCompatible(indexTensorType.getShape())) { if (size == kUnknownSize) return rewriter.notifyMatchFailure(op, "Dynamic index support TBD"); } maxRank = std::max(maxRank, (int)indexTensorType.getRank()); } RankedTensorType resultType = getTypeConverter()->convertType(op.getType()).cast(); SmallVector refinedResultShape = makeShapeTorchCompatible(resultType.getShape()); for (int64_t size : refinedResultShape) { if (size == kUnknownSize) return rewriter.notifyMatchFailure(op, "Dynamic index support TBD"); } for (int i = 0; i < maxRank; i++) { indicesShape.push_back(refinedResultShape[i]); expandShape.push_back(refinedResultShape[i]); concatShape.push_back(refinedResultShape[i]); } if (indexTensors.size() > 1) { expandShape.push_back(1); concatShape.push_back(indexTensors.size()); } SmallVector broadcastedIndices; Type indexElemTy = indexTensors[0].getType().cast().getElementType(); RankedTensorType bcastIndexType = RankedTensorType::get(indicesShape, indexElemTy); for (auto indexTensor : indexTensors) { Value bcastVal = hlo::promoteAndBroadcast(rewriter, indexTensor, bcastIndexType); if (indexTensors.size() > 1) { RankedTensorType reshapeType = RankedTensorType::get(expandShape, indexElemTy); bcastVal = rewriter.create(loc, reshapeType, bcastVal); } broadcastedIndices.push_back(bcastVal); } // Step 2: concat index tensors Value finalIndexTensor = broadcastedIndices[0]; if (broadcastedIndices.size() > 1) { RankedTensorType concatTy = RankedTensorType::get(concatShape, indexElemTy); finalIndexTensor = rewriter.create( loc, concatTy, ValueRange(broadcastedIndices), concatShape.size() - 1); } // Step 3: create stablehlo::GatherOp RankedTensorType finalIndexTy = finalIndexTensor.getType().cast(); int64_t indicesRank = finalIndexTy.getRank(); int64_t numIndicesDim = broadcastedIndices.size(); int64_t indexVecDim = numIndicesDim > 1 ? indicesRank - 1 : indicesRank; SmallVector offsetDims; SmallVector collapsedDims; SmallVector startIndexMap; for (int64_t i = 0; i < numIndicesDim; ++i) { collapsedDims.push_back(i); startIndexMap.push_back(i); } for (int64_t i = numIndicesDim; i < inputTensorType.getRank(); i++) { if (numIndicesDim > 1) { offsetDims.push_back(i + indicesRank - 1 - numIndicesDim); } else { offsetDims.push_back(i + indicesRank - numIndicesDim); } } auto dimsAttr = stablehlo::GatherDimensionNumbersAttr::get( rewriter.getContext(), /*offsetDims=*/offsetDims, /*collapsedSliceDims=*/collapsedDims, /*startIndexMap=*/startIndexMap, /*indexVecDim=*/indexVecDim); SmallVector sliceSizes; auto inputShape = makeShapeTorchCompatible(inputTensorType.getShape()); for (int64_t i = 0; i < inputTensorType.getRank(); ++i) { if (i < numIndicesDim) { sliceSizes.push_back(1); } else { sliceSizes.push_back(inputShape[i]); } } rewriter.replaceOpWithNewOp( op, resultType, input, finalIndexTensor, dimsAttr, rewriter.getI64TensorAttr(sliceSizes)); return success(); } void mlir::torch::torch_to_stablehlo:: populateGatherScatterOpPatternsAndLegality( TypeConverter &typeConverter, RewritePatternSet &patterns, ConversionTarget &target, const TorchToStablehloOptions &options) { MLIRContext *context = patterns.getContext(); #define INSERT_ATENOP_PATTERN(AtenOp) \ target.addIllegalOp(); \ patterns.add>(typeConverter, context, options) INSERT_ATENOP_PATTERN(AtenEmbeddingOp); INSERT_ATENOP_PATTERN(AtenIndexSelectOp); INSERT_ATENOP_PATTERN(AtenGatherOp); INSERT_ATENOP_PATTERN(AtenSliceScatterOp); INSERT_ATENOP_PATTERN(AtenIndexTensorOp); #undef INSERT_ATENOP_PATTERN }