//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//

#include "torch-mlir/Conversion/TorchToMhlo/TorchToMhlo.h"

#include "../PassDetail.h"
#include "./PopulatePatterns.h"
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "torch-mlir/Conversion/Utils/Utils.h"
#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h"

using namespace mlir;
using namespace mlir::torch;
using namespace mlir::torch::Torch;

#ifdef TORCH_MLIR_ENABLE_MHLO_TRUNC_DIMSIZE_TO_I32
static constexpr size_t kMhloDimSizeBits = 32;
#else
static constexpr size_t kMhloDimSizeBits = 64;
#endif

namespace {

// Gathers `input` along a single `axis` using `indices`. The result shape is
// input.shape[:axis] + indices.shape + input.shape[axis + 1:].
Value gatherTensorAlongSingleAxis(PatternRewriter &rewriter, Operation *op,
                                  Value input, Value indices, int64_t axis) {
  auto loc = op->getLoc();
  Type intType = rewriter.getIntegerType(kMhloDimSizeBits);
  Value one = rewriter.create<arith::ConstantOp>(
      loc, rewriter.getIntegerAttr(intType, 1));

  // sliceSizes
  auto inputRankTy = input.getType().dyn_cast<RankedTensorType>();
  auto inputRank = inputRankTy.getRank();
  SmallVector<Value, 4> sliceSizes;
  sliceSizes.reserve(inputRank);
  for (int64_t r = 0; r < inputRank; ++r) {
    if (r == axis) {
      sliceSizes.push_back(one);
    } else {
      sliceSizes.push_back(rewriter.create<arith::IndexCastOp>(
          loc, intType, rewriter.create<tensor::DimOp>(loc, input, r)));
    }
  }
  auto sliceSizesTensor =
      rewriter.create<tensor::FromElementsOp>(loc, sliceSizes);

  // offsetDims
  SmallVector<int64_t, 4> offsetDims;
  offsetDims.reserve(inputRank);
  for (int64_t r = 0; r < axis; ++r) {
    offsetDims.push_back(r);
  }
  auto indicesRankTy = indices.getType().dyn_cast<RankedTensorType>();
  auto indicesRank = indicesRankTy.getRank();
  for (int64_t r = axis + 1; r < inputRank; ++r) {
    offsetDims.push_back(r + indicesRank - 1);
  }

  // collapsedSliceDims
  SmallVector<int64_t, 4> collapsedSliceDims(1, axis);
  // startIndexMap
  SmallVector<int64_t, 4> startIndexMap(1, axis);
  // indexVecDim
  int64_t indexVecDim = indicesRank;
  auto dimsAttr = mhlo::GatherDimensionNumbersAttr::get(
      rewriter.getContext(),
      /*offsetDims=*/offsetDims,
      /*collapsedSliceDims=*/collapsedSliceDims,
      /*startIndexMap=*/startIndexMap,
      /*indexVecDim=*/indexVecDim);

  // outputShape = input.shape[:axis] + indices.shape +
  //               input.shape[axis + 1:]
  auto inputShape = inputRankTy.getShape();
  auto indicesShape = indicesRankTy.getShape();
  SmallVector<int64_t, 4> outputShape(inputShape.begin(),
                                      inputShape.begin() + axis);
  outputShape.insert(outputShape.end(), indicesShape.begin(),
                     indicesShape.end());
  outputShape.insert(outputShape.end(), inputShape.begin() + axis + 1,
                     inputShape.end());

  // Create the output tensor type and emit the gather.
  auto outputTy =
      RankedTensorType::get(outputShape, inputRankTy.getElementType());
  return rewriter
      .create<mhlo::DynamicGatherOp>(loc, outputTy, input, indices,
                                     sliceSizesTensor, dimsAttr)
      .getResult();
}

template <typename AtenOpT>
class ConvertAtenOp : public OpConversionPattern<AtenOpT> {
public:
  using OpConversionPattern<AtenOpT>::OpConversionPattern;
  using OpAdaptor = typename AtenOpT::Adaptor;
  LogicalResult
  matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

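// AtenEmbeddingOp and AtenIndexSelectOp below are both lowered through
// gatherTensorAlongSingleAxis. Illustrative shape behavior (assuming static
// shapes): a [V, D] weight gathered with [N] indices along axis 0 yields an
// [N, D] result via mhlo::DynamicGatherOp, which is then cast to the
// converted result type with mhlo::ConvertOp.
//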
// Ref: https://pytorch.org/docs/stable/generated/torch.nn.functional.embedding.html
// padding_idx (int, optional)
//  – If specified, the entries at padding_idx do not contribute to the
//  gradient; therefore, the embedding vector at padding_idx is not updated
//  during training, i.e. it remains as a fixed “pad”.
// scale_grad_by_freq (boolean, optional)
//  – If given, this will scale gradients by the inverse of frequency of the
//  words in the mini-batch. Default False.
// sparse (bool, optional)
//  – If True, gradient w.r.t. weight matrix will be a sparse tensor.
template <>
LogicalResult ConvertAtenOp<AtenEmbeddingOp>::matchAndRewrite(
    AtenEmbeddingOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  auto weight = adaptor.weight();
  auto weightTy = weight.getType().template cast<RankedTensorType>();
  if (!weightTy)
    return op.emitError("only ranked tensor types are supported");

  int64_t padding_idx;
  if (!matchPattern(op.padding_idx(), m_TorchConstantInt(&padding_idx)))
    return rewriter.notifyMatchFailure(
        op, "only constant padding_idx is currently supported");

  bool scale_grad_by_freq;
  if (!matchPattern(op.scale_grad_by_freq(),
                    m_TorchConstantBool(&scale_grad_by_freq)))
    return rewriter.notifyMatchFailure(
        op, "only constant scale_grad_by_freq is currently supported");
  if (scale_grad_by_freq)
    return rewriter.notifyMatchFailure(
        op, "scale gradients is currently not supported");

  bool sparse;
  if (!matchPattern(op.sparse(), m_TorchConstantBool(&sparse)))
    return rewriter.notifyMatchFailure(
        op, "only constant sparse is currently supported");
  if (sparse)
    return rewriter.notifyMatchFailure(
        op, "sparse gradients is currently not supported");

  Value output =
      gatherTensorAlongSingleAxis(rewriter, op, weight, adaptor.indices(), 0);
  rewriter.replaceOpWithNewOp<mhlo::ConvertOp>(
      op, getTypeConverter()->convertType(op.getType()), output);

  return success();
}

template <>
LogicalResult ConvertAtenOp<AtenIndexSelectOp>::matchAndRewrite(
    AtenIndexSelectOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  auto self = adaptor.self();
  auto selfTy = self.getType().template cast<RankedTensorType>();
  if (!selfTy)
    return op.emitError("only ranked tensor types are supported");

  int64_t dim;
  if (!matchPattern(op.dim(), m_TorchConstantInt(&dim)))
    return rewriter.notifyMatchFailure(
        op, "only constant dim is currently supported");

  Value output =
      gatherTensorAlongSingleAxis(rewriter, op, self, adaptor.index(), dim);

  rewriter.replaceOpWithNewOp<mhlo::ConvertOp>(
      op, getTypeConverter()->convertType(op.getType()), output);

  return success();
}

} // namespace

void mlir::torch::torch_to_mhlo::populateGatherOpPatternsAndLegality(
    TypeConverter &typeConverter, RewritePatternSet &patterns,
    ConversionTarget &target) {
  MLIRContext *context = patterns.getContext();

#define INSERT_ATENOP_PATTERN(AtenOp)                                          \
  target.addIllegalOp<AtenOp>();                                               \
  patterns.add<ConvertAtenOp<AtenOp>>(typeConverter, context);
  INSERT_ATENOP_PATTERN(AtenEmbeddingOp);
  INSERT_ATENOP_PATTERN(AtenIndexSelectOp);
#undef INSERT_ATENOP_PATTERN
}