//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//

#include "torch-mlir/Conversion/TorchToStablehlo/TorchToStablehlo.h"

#include "../PassDetail.h"
#include "PopulatePatterns.h"

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "stablehlo/dialect/StablehloOps.h"
#include "torch-mlir/Conversion/TorchToStablehlo/StablehloLegalizeUtils.h"
#include "torch-mlir/Conversion/Utils/Utils.h"
#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
#include "torch-mlir/Dialect/Torch/Utils/TorchUpstream.h"
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionDialect.h"
#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h"
#include <numeric>

using namespace mlir;
using namespace mlir::torch;
using namespace mlir::torch::Torch;
using namespace mlir::torch::TorchConversion;
using namespace mlir::torch::torch_to_stablehlo;

namespace {
// A dimension index from torch.dialect might outside the range [0, dimSize].
// The function is used to normalize the input index into the range.
Value getNormalizedDimSizeInternal(PatternRewriter &rewriter, Operation *op, Value index, Value dimSize) { auto loc = op->getLoc(); Value zero = rewriter.create( loc, rewriter.getIntegerAttr(rewriter.getI64Type(), 0)); // To normalize index into range [-dimSize, dimSize] // index = min(max(-dimSize, index), dimSize) auto negDimSize = rewriter.create(loc, zero, dimSize); index = rewriter.create(loc, negDimSize, index); index = rewriter.create(loc, dimSize, index); auto dimSizePlusIndex = rewriter.create(loc, dimSize, index); auto indexPositive = rewriter.create( loc, arith::CmpIPredicate::sge, index, zero); // get positive index: (index >=0) ? index: index + dimSize return rewriter.create(loc, indexPositive, index, dimSizePlusIndex); } Value getDynamicSliceInternal(PatternRewriter &rewriter, Operation *op, Type outTy, Value input, Value startIndex, Value endIndex, Value step, size_t dimIndex, ArrayRef dimSizes, size_t dimSizeIndexBits) { auto loc = op->getLoc(); // startIndex & endIndex has been normailized into range [0, dSize] Type intType = rewriter.getIntegerType(dimSizeIndexBits); Value zero = rewriter.create( loc, rewriter.getIntegerAttr(intType, 0)); Value one = rewriter.create( loc, rewriter.getIntegerAttr(intType, 1)); SmallVector startIndices; SmallVector endIndices; SmallVector strides; auto inputTy = input.getType().dyn_cast(); size_t rank = inputTy.getRank(); startIndices.reserve(rank); endIndices.reserve(rank); strides.reserve(rank); auto endIndexIsZero = rewriter.create( loc, arith::CmpIPredicate::eq, endIndex, zero); endIndex = rewriter.create(loc, endIndexIsZero, dimSizes[dimIndex], endIndex); for (size_t r = 0; r < rank; ++r) { if (r == dimIndex) { startIndices.push_back(startIndex); endIndices.push_back(endIndex); strides.push_back(step); } else { startIndices.push_back(zero); endIndices.push_back(dimSizes[r]); strides.push_back(one); } } auto startTensor = rewriter.create(loc, startIndices).getResult(); auto endTensor = rewriter.create(loc, 
endIndices).getResult(); auto stridesTensor = rewriter.create(loc, strides).getResult(); return rewriter.create( loc, outTy, input, startTensor, endTensor, stridesTensor); } // Get a dynamic slice of the tensor from startIndex to endIndex with stride // step on the specifed dimension. The input startIndex(default to 0), // endIndex(default to dimSize), and step(default to 1) can be optional. FailureOr getDynamicSlice(PatternRewriter &rewriter, Operation *op, Type outTy, Value input, std::optional startIndexOpt, std::optional endIndexOpt, std::optional stepOpt, int64_t dim, size_t dimSizeIndexBits) { auto loc = op->getLoc(); auto inputTy = input.getType().dyn_cast(); auto rank = inputTy.getRank(); dim = (dim + rank) % rank; Value dimSize = rewriter.create( loc, rewriter.getI64Type(), rewriter.create(loc, input, dim)); Value normStartIndex = startIndexOpt ? getNormalizedDimSizeInternal(rewriter, op, *startIndexOpt, dimSize) : rewriter.create( loc, rewriter.getIntegerAttr(rewriter.getI64Type(), 0)); Value normEndIndex = endIndexOpt ? getNormalizedDimSizeInternal(rewriter, op, *endIndexOpt, dimSize) : dimSize; Value step = stepOpt ? *stepOpt : rewriter.create( loc, rewriter.getIntegerAttr(rewriter.getI64Type(), 1)); if (dimSizeIndexBits == 32) { Type intType = rewriter.getIntegerType(dimSizeIndexBits); normStartIndex = rewriter.create(loc, intType, normStartIndex); normEndIndex = rewriter.create(loc, intType, normEndIndex); step = rewriter.create(loc, intType, step); } FailureOr> dimSizesInfo = hlo::getDimSizesOfTensor(rewriter, op, input, dimSizeIndexBits); if (failed(dimSizesInfo)) return rewriter.notifyMatchFailure( op, "failed to get dimension sizes of the input"); auto dimSizes = *dimSizesInfo; return getDynamicSliceInternal(rewriter, op, outTy, input, normStartIndex, normEndIndex, step, dim, dimSizes, dimSizeIndexBits); } // This defines a template to construct ops whose legalizations are // specialized. 
template class ConvertAtenViewOp : public ConvertAtenOp { public: using ConvertAtenOp::ConvertAtenOp; using OpAdaptor = typename AtenOpT::Adaptor; LogicalResult matchAndRewrite(AtenOpT op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto rankType = adaptor.getSelf().getType().template dyn_cast(); if (!rankType) return op.emitError("Only ranked tensor types are currently supported"); SmallVector dimSizes; if (!getAtenViewOpSizes(op, adaptor, rewriter, dimSizes)) { return op.emitError("Dims size must be a list of Scalar"); } auto loc = op.getLoc(); auto newRank = dimSizes.size(); if (newRank == 0 || rankType.getRank() == 0) { rewriter.replaceOpWithNewOp( op, OpConversionPattern::getTypeConverter()->convertType( op.getType()), adaptor.getSelf()); return success(); } std::for_each(dimSizes.begin(), dimSizes.end(), [&](Value &dSize) { dSize = rewriter.create(loc, dSize).getResult(); return dSize; }); const auto &options = ConvertAtenOp::getOptions(); Type intType = rewriter.getIntegerType(options.dimSizeIndexBits); if (options.dimSizeIndexBits == 32) { // The i64 calculation is much slower than i32 on some devices, such as // Nvidia GPU. 
One can truncate from i64 to i32 since dimension sizes are // unlikely to exceed the range of i32(4GiB) std::for_each(dimSizes.begin(), dimSizes.end(), [&](Value &dSize) { // dimSize: cast i64 -> i32 dSize = rewriter.create(loc, intType, dSize); return dSize; }); } Value numel = rewriter.create( loc, rewriter.getIntegerAttr(intType, 1)); for (auto d : dimSizes) { numel = rewriter.create(loc, numel, d); } numel = rewriter.create(loc, rewriter.getIndexType(), numel); if (dimSizes.size() == 0) { rewriter.replaceOpWithNewOp( op, OpConversionPattern::getTypeConverter()->convertType( op.getType()), adaptor.getSelf()); return success(); } Value stablehloShape = rewriter.create(loc, dimSizes); Value computedShape = rewriter.create( loc, stablehloShape.getType(), numel, stablehloShape); rewriter.replaceOpWithNewOp( op, OpConversionPattern::getTypeConverter()->convertType( op.getType()), adaptor.getSelf(), computedShape); return success(); } bool getAtenViewOpSizes(AtenOpT op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter, SmallVector &dimSizes) const; }; template <> bool ConvertAtenViewOp::getAtenViewOpSizes( AtenViewOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter, SmallVector &dimSizes) const { return getListConstructElements(adaptor.getSize(), dimSizes); } template <> bool ConvertAtenViewOp::getAtenViewOpSizes( AtenReshapeOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter, SmallVector &dimSizes) const { return getListConstructElements(adaptor.getShape(), dimSizes); } } // namespace template <> LogicalResult ConvertAtenOp::matchAndRewrite( AtenSliceTensorOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { auto self = adaptor.getSelf(); auto selfTy = self.getType().cast(); if (!selfTy) return op.emitError("only ranked tensor types are supported"); auto outTy = getTypeConverter()->convertType(op.getType()).cast(); int64_t dim; if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim))) return rewriter.notifyMatchFailure( op, 
"only constant dim is currently supported"); int64_t inputRank = selfTy.getRank(); dim = toPositiveDim(dim, inputRank); if (!isValidDim(dim, inputRank)) return rewriter.notifyMatchFailure(op, "dim is statically invalid"); auto getOptionalVal = [&](Value val) -> std::optional { if (val.getType().isa()) { return std::nullopt; } else { return val; } }; std::optional start = getOptionalVal(adaptor.getStart()); std::optional end = getOptionalVal(adaptor.getEnd()); std::optional step = getOptionalVal(adaptor.getStep()); FailureOr sliceInfo = getDynamicSlice(rewriter, op, outTy, self, start, end, step, dim, options.dimSizeIndexBits); if (failed(sliceInfo)) return op.emitError("can not create a dynmaic slice"); auto slice = *sliceInfo; rewriter.replaceOp(op, slice); return success(); } template <> LogicalResult ConvertAtenOp::matchAndRewrite( AtenSqueezeOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { auto self = adaptor.getSelf(); auto selfTy = self.getType().cast(); if (!selfTy) return op.emitError("only ranked tensor types are supported"); auto rank = selfTy.getRank(); if (rank == 0) return rewriter.notifyMatchFailure( op, "The rank of tensor must be greater than 0"); SmallVector dims; dims.reserve(rank); for (int r = 0; r < rank; ++r) { auto dSize = selfTy.getShape()[r]; if (dSize == ShapedType::kDynamic) return rewriter.notifyMatchFailure( op, "the size of the dimension being squeezed can't be unknown"); if (dSize != 1) dims.push_back(r); } if (dims.size() == 0) { rewriter.replaceOpWithNewOp( op, getTypeConverter()->convertType(op.getType()), self); return success(); } auto newDimSizesInfo = hlo::getDimSizesOfTensor(rewriter, op, self, dims, options.dimSizeIndexBits); if (failed(newDimSizesInfo)) return rewriter.notifyMatchFailure( op, "failed to get dimension sizes of the input"); auto newDimSizes = *newDimSizesInfo; auto stablehloShape = rewriter.create(op.getLoc(), newDimSizes); rewriter.replaceOpWithNewOp( op, 
getTypeConverter()->convertType(op.getType()), self, stablehloShape); return success(); } template <> LogicalResult ConvertAtenOp::matchAndRewrite( AtenSqueezeDimOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { auto self = adaptor.getSelf(); auto selfTy = self.getType().cast(); if (!selfTy) return op.emitError("only ranked tensor types are supported"); auto rank = selfTy.getRank(); if (rank == 0) return rewriter.notifyMatchFailure( op, "the rank of tensor must be greater than 0"); int64_t dim; if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim))) return rewriter.notifyMatchFailure( op, "only constant dim is currently supported"); dim = toPositiveDim(dim, rank); if (!isValidDim(dim, rank)) return rewriter.notifyMatchFailure(op, "dim is statically invalid"); if (selfTy.getShape()[dim] != 1) { if (selfTy.getShape()[dim] == ShapedType::kDynamic) return rewriter.notifyMatchFailure( op, "the size of the dimension being squeezed is can't be unknown"); rewriter.replaceOp(op, adaptor.getSelf()); return success(); } SmallVector dims(rank); std::iota(dims.begin(), dims.end(), 0); dims.erase(dims.begin() + dim); if (dims.size() == 0) { rewriter.replaceOpWithNewOp( op, getTypeConverter()->convertType(op.getType()), self); return success(); } auto newDimSizesInfo = hlo::getDimSizesOfTensor(rewriter, op, self, dims, options.dimSizeIndexBits); if (failed(newDimSizesInfo)) return rewriter.notifyMatchFailure( op, "failed to get dimension sizes of the input"); auto newDimSizes = *newDimSizesInfo; auto stablehloShape = rewriter.create(op.getLoc(), newDimSizes); rewriter.replaceOpWithNewOp( op, getTypeConverter()->convertType(op.getType()), self, stablehloShape); return success(); } template <> LogicalResult ConvertAtenOp::matchAndRewrite( AtenUnsqueezeOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { auto selfType = adaptor.getSelf().getType().dyn_cast(); if (!selfType) { return op.emitError("only tensor types are currently supported"); } 
int64_t dim; if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim))) return op->emitError("dim must be a Scalar constant"); int64_t inputRank = adaptor.getSelf().getType().cast().getRank(); dim = toPositiveDim(dim, inputRank + 1); if (!isValidDim(dim, inputRank + 1)) return rewriter.notifyMatchFailure(op, "dim is statically invalid"); auto unsqzTensorInfo = hlo::unsqueezeTensor(rewriter, op, adaptor.getSelf(), {dim}, options.dimSizeIndexBits); if (failed(unsqzTensorInfo)) return rewriter.notifyMatchFailure(op, "failed to create unsqueezed tensor"); rewriter.replaceOp(op, *unsqzTensorInfo); return success(); } void mlir::torch::torch_to_stablehlo::populateViewLikeOpPatternsAndLegality( TypeConverter &typeConverter, RewritePatternSet &patterns, ConversionTarget &target, const TorchToStablehloOptions &options) { MLIRContext *context = patterns.getContext(); #define INSERT_ATENOP_PATTERN(AtenOp) \ target.addIllegalOp(); \ patterns.add>(typeConverter, context, options) INSERT_ATENOP_PATTERN(AtenSliceTensorOp); INSERT_ATENOP_PATTERN(AtenSqueezeOp); INSERT_ATENOP_PATTERN(AtenSqueezeDimOp); INSERT_ATENOP_PATTERN(AtenUnsqueezeOp); #undef INSERT_ATENOP_PATTERN #define INSERT_VIEW_OP_PATTERN(AtenOp) \ target.addIllegalOp(); \ patterns.add>(typeConverter, context, options) INSERT_VIEW_OP_PATTERN(AtenViewOp); INSERT_VIEW_OP_PATTERN(AtenReshapeOp); #undef INSERT_VIEW_OP_PATTERN }