//===----------------------------------------------------------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // Also available under a BSD-style license. See LICENSE. // //===----------------------------------------------------------------------===// #include "torch-mlir/Conversion/TorchToMhlo/TorchToMhlo.h" #include "../PassDetail.h" #include "./MhloLegalizeUtils.h" #include "./PopulatePatterns.h" #include "mhlo/IR/hlo_ops.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "torch-mlir/Conversion/Utils/Utils.h" #include "torch-mlir/Dialect/Torch/IR/TorchDialect.h" #include "torch-mlir/Dialect/Torch/IR/TorchOps.h" #include "torch-mlir/Dialect/Torch/Utils/Utils.h" #include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h" using namespace mlir; using namespace mlir::torch; using namespace mlir::torch::Torch; using namespace mlir::torch::torch_to_mhlo; namespace { Value gatherTensorAlongSingleAxis(PatternRewriter &rewriter, Operation *op, Value input, Value indices, int64_t axis, size_t dimSizeIndexBits) { auto loc = op->getLoc(); Type intType = rewriter.getIntegerType(dimSizeIndexBits); Value one = rewriter.create( loc, rewriter.getIntegerAttr(intType, 1)); // sliceSizes auto inputRankTy = input.getType().dyn_cast(); auto inputRank = inputRankTy.getRank(); SmallVector sliceSizes; sliceSizes.reserve(inputRank); for (int64_t r = 0; r < inputRank; ++r) { if (r == axis) { sliceSizes.push_back(one); } else { sliceSizes.push_back(rewriter.create( loc, intType, rewriter.create(loc, input, r))); } } auto sliceSizesTensor = rewriter.create(loc, sliceSizes); // offsetDims SmallVector offsetDims; offsetDims.reserve(inputRank); for (int64_t r = 0; r < axis; ++r) { offsetDims.push_back(r); } auto indicesRankTy = indices.getType().dyn_cast(); auto indicesRank = indicesRankTy.getRank(); for (int64_t r = axis + 1; r < inputRank; ++r) { offsetDims.push_back(r + indicesRank - 1); } // collapsedSliceDims SmallVector collapsedSliceDims(1, axis); // startIndexMap SmallVector startIndexMap(1, axis); // indexVecDim int64_t indexVecDim = indicesRank; auto dimsAttr = mhlo::GatherDimensionNumbersAttr::get( rewriter.getContext(), /*offsetDims=*/offsetDims, /*collapsedSliceDims=*/collapsedSliceDims, /*startIndexMap=*/startIndexMap, /*indexVecDim=*/indexVecDim); // outputShape = input.shape[:axis] + indices.shape + // input.shape[axis + 1:] auto inputShape = inputRankTy.getShape(); auto indicesShape = indicesRankTy.getShape(); SmallVector outputShape(inputShape.begin(), inputShape.begin() + axis); outputShape.insert(outputShape.end(), indicesShape.begin(), indicesShape.end()); outputShape.insert(outputShape.end(), inputShape.begin() + axis + 1, inputShape.end()); // create output tensor type auto outputTy = RankedTensorType::get(outputShape, inputRankTy.getElementType()); return rewriter .create(loc, outputTy, input, indices, sliceSizesTensor, dimsAttr) .getResult(); } } // namespace // Ref: https://pytorch.org/docs/stable/generated/torch.nn.functional.embedding.html // padding_idx (int, optional) // – If specified, the entries at padding_idx do not contribute to the gradient; // therefore, the embedding vector at padding_idx is not updated during training, // i.e. it remains as a fixed “pad”. // scale_grad_by_freq (boolean, optional) // – If given, this will scale gradients by the inverse of frequency of the // words in the mini-batch. Default False. // sparse (bool, optional) // – If True, gradient w.r.t. weight matrix will be a sparse tensor. template <> LogicalResult ConvertAtenOp::matchAndRewrite( AtenEmbeddingOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { auto weight = adaptor.getWeight(); auto weightTy = weight.getType().cast(); if (!weightTy) return op.emitError("only ranked tensor types are supported"); int64_t padding_idx; if (!matchPattern(op.getPaddingIdx(), m_TorchConstantInt(&padding_idx))) return rewriter.notifyMatchFailure( op, "only constant padding_idx is currently supported"); bool scale_grad_by_freq; if (!matchPattern(op.getScaleGradByFreq(), m_TorchConstantBool(&scale_grad_by_freq))) return rewriter.notifyMatchFailure( op, "only constant scale_grad_by_freq is currently supported"); if (scale_grad_by_freq) return rewriter.notifyMatchFailure( op, "scale gradients is currently not supported"); bool sparse; if (!matchPattern(op.getSparse(), m_TorchConstantBool(&sparse))) return rewriter.notifyMatchFailure( op, "only constant sparse is currently supported"); if (sparse) return rewriter.notifyMatchFailure( op, "sparse gradients is currently not supported"); Value output = gatherTensorAlongSingleAxis( rewriter, op, weight, adaptor.getIndices(), 0, options.dimSizeIndexBits); rewriter.replaceOpWithNewOp( op, getTypeConverter()->convertType(op.getType()), output); return success(); } template <> LogicalResult ConvertAtenOp::matchAndRewrite( AtenIndexSelectOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { auto self = adaptor.getSelf(); auto selfTy = self.getType().cast(); if (!selfTy) return op.emitError("only ranked tensor types are supported"); int64_t dim; if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim))) return rewriter.notifyMatchFailure( op, "only constant dim is currently supported"); Value output = gatherTensorAlongSingleAxis( rewriter, op, self, adaptor.getIndex(), dim, options.dimSizeIndexBits); rewriter.replaceOpWithNewOp( op, getTypeConverter()->convertType(op.getType()), output); return success(); } // AtenGatherOp template <> LogicalResult ConvertAtenOp::matchAndRewrite( AtenGatherOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { Location loc = op->getLoc(); Value input = adaptor.getSelf(); Value index = adaptor.getIndex(); auto inputType = input.getType().cast(); auto indexType = index.getType().cast(); auto indexElemType = indexType.getElementType(); if (indexType.getRank() != inputType.getRank()) { return op.emitError("`index` and `input` param should have the same rank"); } int64_t dim; if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim))) { return rewriter.notifyMatchFailure( op, "only constant int `dim` param supported"); } dim = toPositiveDim(dim, inputType.getRank()); if (!isValidDim(dim, inputType.getRank())) { return rewriter.notifyMatchFailure(op, "invalid `dim` param detected"); } bool sparseGrad = false; if (!matchPattern(op.getSparseGrad(), m_TorchConstantBool(&sparseGrad))) { return rewriter.notifyMatchFailure( op, "only constant boolean `sparse_grad` param supported"); } auto options = getOptions(); auto indexShapeInfo = mhlo::getDimSizesOfTensor(rewriter, op, index, options.dimSizeIndexBits); if (failed(indexShapeInfo)) { return rewriter.notifyMatchFailure( op, "failed to get dim sizes of `index` param"); } auto intType = rewriter.getIntegerType(options.dimSizeIndexBits); auto one = rewriter.create( loc, rewriter.getIntegerAttr(intType, 1)); auto toConcatIndexShapeValueVec = *indexShapeInfo; toConcatIndexShapeValueVec.push_back(one); auto toConcatIndexShape = rewriter.create(loc, toConcatIndexShapeValueVec); auto indexShape = indexType.getShape(); SmallVector toConcatIndexShapeVec(indexShape.begin(), indexShape.end()); toConcatIndexShapeVec.push_back(1); RankedTensorType toConcatIndexType = RankedTensorType::get(toConcatIndexShapeVec, indexElemType); SmallVector toConcat; for (int64_t i = 0; i < inputType.getRank(); ++i) { if (i == dim) { toConcat.push_back(rewriter.create( loc, toConcatIndexType, index, toConcatIndexShape)); } else { toConcat.push_back(rewriter.create( loc, toConcatIndexType, toConcatIndexShape, rewriter.getI64IntegerAttr(i))); } } auto gatherIndicies = rewriter.create( loc, toConcat, static_cast(inputType.getRank())); SmallVector sliceSizes(inputType.getRank(), 1); int64_t indexVecDim = inputType.getRank(); SmallVector collapsedDims; SmallVector startIndexMap; for (int64_t i = 0; i < inputType.getRank(); ++i) { collapsedDims.push_back(i); startIndexMap.push_back(i); } auto dimsAttr = mhlo::GatherDimensionNumbersAttr::get( rewriter.getContext(), /*offsetDims=*/{}, /*collapsedSliceDims=*/collapsedDims, /*startIndexMap=*/startIndexMap, /*indexVecDim=*/indexVecDim); rewriter.replaceOpWithNewOp( op, input, gatherIndicies, dimsAttr, rewriter.getI64TensorAttr(sliceSizes)); return success(); } void mlir::torch::torch_to_mhlo::populateGatherOpPatternsAndLegality( TypeConverter &typeConverter, RewritePatternSet &patterns, ConversionTarget &target, const TorchToMhloOptions &options) { MLIRContext *context = patterns.getContext(); #define INSERT_ATENOP_PATTERN(AtenOp) \ target.addIllegalOp(); \ patterns.add>(typeConverter, context, options) INSERT_ATENOP_PATTERN(AtenEmbeddingOp); INSERT_ATENOP_PATTERN(AtenIndexSelectOp); INSERT_ATENOP_PATTERN(AtenGatherOp); #undef INSERT_ATENOP_PATTERN }