//===----------------------------------------------------------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // Also available under a BSD-style license. See LICENSE. // //===----------------------------------------------------------------------===// #include "torch-mlir/Conversion/TorchToMhlo/TorchToMhlo.h" #include "../PassDetail.h" #include "./MhloLegalizeUtils.h" #include "./PopulatePatterns.h" #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "torch-mlir/Conversion/Utils/Utils.h" #include "torch-mlir/Dialect/Torch/IR/TorchDialect.h" #include "torch-mlir/Dialect/Torch/IR/TorchOps.h" #include "torch-mlir/Dialect/Torch/Utils/TorchUpstream.h" #include "torch-mlir/Dialect/Torch/Utils/Utils.h" #include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionDialect.h" #include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h" #include using namespace mlir; using namespace mlir::torch; using namespace mlir::torch::Torch; using namespace mlir::torch::TorchConversion; namespace { // A dimension index from torch.dialect might outside the range [0, dimSize]. // The function is used to normalize the input index into the range. Value getNormalizedDimSizeInternal(PatternRewriter &rewriter, Operation *op, Value index, Value dimSize) { auto loc = op->getLoc(); Value zero = rewriter.create( loc, rewriter.getIntegerAttr(rewriter.getI64Type(), 0)); // To normalize index into range [-dimSize, dimSize] // index = min(max(-dimSize, index), dimSize) auto negDimSize = rewriter.create(loc, zero, dimSize); index = rewriter.create(loc, negDimSize, index); index = rewriter.create(loc, dimSize, index); auto dimSizePlusIndex = rewriter.create(loc, dimSize, index); auto indexPositive = rewriter.create( loc, arith::CmpIPredicate::sge, index, zero); // get positive index: (index >=0) ? index: index + dimSize return rewriter.create(loc, indexPositive, index, dimSizePlusIndex); } Value getDynamicSliceInternal(PatternRewriter &rewriter, Operation *op, Value input, Value startIndex, Value endIndex, Value step, size_t dimIndex, ArrayRef dimSizes) { auto loc = op->getLoc(); // startIndex & endIndex has been normailized into range [0, dSize] Type intType = rewriter.getIntegerType(mhlo::kMhloDimSizeBits); Value zero = rewriter.create( loc, rewriter.getIntegerAttr(intType, 0)); Value one = rewriter.create( loc, rewriter.getIntegerAttr(intType, 1)); SmallVector startIndices; SmallVector endIndices; SmallVector strides; auto inputTy = input.getType().dyn_cast(); size_t rank = inputTy.getRank(); startIndices.reserve(rank); endIndices.reserve(rank); strides.reserve(rank); auto endIndexIsZero = rewriter.create( loc, arith::CmpIPredicate::eq, endIndex, zero); endIndex = rewriter.create(loc, endIndexIsZero, dimSizes[dimIndex], endIndex); for (size_t r = 0; r < rank; ++r) { if (r == dimIndex) { startIndices.push_back(startIndex); endIndices.push_back(endIndex); strides.push_back(step); } else { startIndices.push_back(zero); endIndices.push_back(dimSizes[r]); strides.push_back(one); } } auto startTensor = rewriter.create(loc, startIndices).getResult(); auto endTensor = rewriter.create(loc, endIndices).getResult(); auto stridesTensor = rewriter.create(loc, strides).getResult(); auto inputShape = inputTy.getShape(); SmallVector sliceShape(inputShape.begin(), inputShape.end()); sliceShape[dimIndex] = ShapedType::kDynamicSize; auto sliceoutputTy = RankedTensorType::get(sliceShape, inputTy.getElementType()); return rewriter.create( loc, sliceoutputTy, input, startTensor, endTensor, stridesTensor); } // Get a dynamic slice of the tensor from startIndex to endIndex with stride // step on the specifed dimension. The input startIndex(default to 0), // endIndex(default to dimSize), and step(default to 1) can be optional. FailureOr getDynamicSlice(PatternRewriter &rewriter, Operation *op, Value input, llvm::Optional startIndexOpt, llvm::Optional endIndexOpt, llvm::Optional stepOpt, int64_t dim) { auto loc = op->getLoc(); auto inputTy = input.getType().dyn_cast(); auto rank = inputTy.getRank(); dim = (dim + rank) % rank; Value dimSize = rewriter.create( loc, rewriter.getI64Type(), rewriter.create(loc, input, dim)); Value normStartIndex = startIndexOpt ? getNormalizedDimSizeInternal(rewriter, op, *startIndexOpt, dimSize) : rewriter.create( loc, rewriter.getIntegerAttr(rewriter.getI64Type(), 0)); Value normEndIndex = endIndexOpt ? getNormalizedDimSizeInternal(rewriter, op, *endIndexOpt, dimSize) : dimSize; Value step = stepOpt ? *stepOpt : rewriter.create( loc, rewriter.getIntegerAttr(rewriter.getI64Type(), 1)); #ifdef TORCH_MLIR_ENABLE_MHLO_TRUNC_DIMSIZE_TO_I32 auto i32Type = rewriter.getIntegerType(mhlo::kMhloDimSizeBits); normStartIndex = rewriter.create(loc, i32Type, normStartIndex); normEndIndex = rewriter.create(loc, i32Type, normEndIndex); step = rewriter.create(loc, i32Type, step); #endif FailureOr> dimSizesInfo = mhlo::getDimSizesOfTensor(rewriter, op, input); if (failed(dimSizesInfo)) return rewriter.notifyMatchFailure( op, "failed to get dimension sizes of the input"); auto dimSizes = *dimSizesInfo; return getDynamicSliceInternal(rewriter, op, input, normStartIndex, normEndIndex, step, dim, dimSizes); } template class ConvertAtenOp : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; using OpAdaptor = typename AtenOpT::Adaptor; LogicalResult matchAndRewrite(AtenOpT op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override; }; template <> LogicalResult ConvertAtenOp::matchAndRewrite( AtenSliceTensorOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { auto self = adaptor.self(); auto selfTy = self.getType().template cast(); if (!selfTy) return op.emitError("only ranked tensor types are supported"); int64_t dim; if (!matchPattern(op.dim(), m_TorchConstantInt(&dim))) return rewriter.notifyMatchFailure( op, "only constant dim is currently supported"); auto getOptionalVal = [&](Value val) -> llvm::Optional { if (val.getType().isa()) { return llvm::None; } else { return val; } }; llvm::Optional start = getOptionalVal(adaptor.start()); llvm::Optional end = getOptionalVal(adaptor.end()); llvm::Optional step = getOptionalVal(adaptor.step()); FailureOr sliceInfo = getDynamicSlice(rewriter, op, self, start, end, step, dim); if (failed(sliceInfo)) return op.emitError("can not create a dynmaic slice"); auto slice = *sliceInfo; rewriter.replaceOpWithNewOp( op, getTypeConverter()->convertType(op.getType()), slice); return success(); } // This defines a template to construct ops whose legalizations are // specialized. template class ConvertAtenViewOp : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; using OpAdaptor = typename AtenOpT::Adaptor; LogicalResult matchAndRewrite( AtenOpT op, OpAdaptor adaptor, ConversionPatternRewriter& rewriter) const override { auto rankType = adaptor.self().getType().template dyn_cast(); if (!rankType) return op.emitError("Only ranked tensor types are currently supported"); SmallVector dimSizes; if (!getAtenViewOpSizes(op, adaptor, rewriter, dimSizes)) { return op.emitError("Dims size must be a list of Scalar"); } auto loc = op.getLoc(); auto newRank = dimSizes.size(); if (newRank == 0 || rankType.getRank() == 0) { rewriter.replaceOpWithNewOp( op, OpConversionPattern::getTypeConverter()->convertType( op.getType()), adaptor.self()); return success(); } std::for_each(dimSizes.begin(), dimSizes.end(), [&](Value& dSize) { dSize = rewriter.create(loc, dSize).getResult(); return dSize; }); #ifdef TORCH_MLIR_ENABLE_MHLO_TRUNC_DIMSIZE_TO_I32 // The i64 calculation is much slower than i32 on some devices, such as Nvidia GPU. // One can truncate from i64 to i32 since dimension sizes are unlikely to exceed // the range of i32(4GiB) std::for_each(dimSizes.begin(), dimSizes.end(), [&](Value& dSize) { // dimSize: cast i64 -> i32 dSize = rewriter.create(loc, rewriter.getI32Type(), dSize); return dSize; }); #endif Type intType = rewriter.getIntegerType(mhlo::kMhloDimSizeBits); Value numel = rewriter.create( loc, rewriter.getIntegerAttr(intType, 1)); for (auto d : dimSizes) { numel = rewriter.create(loc, numel, d); } numel = rewriter.create(loc, rewriter.getIndexType(), numel); Value mhloShape = rewriter.create(loc, dimSizes); Value computedShape = rewriter.create( loc, mhloShape.getType(), numel, mhloShape); rewriter.replaceOpWithNewOp( op, OpConversionPattern::getTypeConverter()->convertType( op.getType()), adaptor.self(), computedShape); return success(); } bool getAtenViewOpSizes( AtenOpT op, OpAdaptor adaptor, ConversionPatternRewriter& rewriter, SmallVector& dimSizes) const; }; template <> bool ConvertAtenViewOp::getAtenViewOpSizes( AtenViewOp op, OpAdaptor adaptor, ConversionPatternRewriter& rewriter, SmallVector& dimSizes) const { return getListConstructElements(adaptor.size(), dimSizes); } template <> bool ConvertAtenViewOp::getAtenViewOpSizes( AtenReshapeOp op, OpAdaptor adaptor, ConversionPatternRewriter& rewriter, SmallVector& dimSizes) const { return getListConstructElements(adaptor.shape(), dimSizes); } template <> LogicalResult ConvertAtenOp::matchAndRewrite( AtenSqueezeOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { auto self = adaptor.self(); auto selfTy = self.getType().template cast(); if (!selfTy) return op.emitError("only ranked tensor types are supported"); auto rank = selfTy.getRank(); if (rank == 0) return rewriter.notifyMatchFailure( op, "The rank of tensor must be greater than 0"); SmallVector dims; dims.reserve(rank); for (int r = 0; r < rank; ++r) { auto dSize = selfTy.getShape()[r]; if (dSize == ShapedType::kDynamicSize) return rewriter.notifyMatchFailure( op, "the size of the dimension being squeezed can't be unknown"); if (dSize != 1) dims.push_back(r); } auto newDimSizesInfo = mhlo::getDimSizesOfTensor(rewriter, op, self, dims); if (failed(newDimSizesInfo)) return rewriter.notifyMatchFailure( op, "failed to get dimension sizes of the input"); auto newDimSizes = *newDimSizesInfo; auto mhloShape = rewriter.create(op.getLoc(), newDimSizes); rewriter.replaceOpWithNewOp( op, getTypeConverter()->convertType(op.getType()), self, mhloShape); return success(); } template <> LogicalResult ConvertAtenOp::matchAndRewrite( AtenSqueezeDimOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { auto self = adaptor.self(); auto selfTy = self.getType().template cast(); if (!selfTy) return op.emitError("only ranked tensor types are supported"); int64_t dim; if (!matchPattern(op.dim(), m_TorchConstantInt(&dim))) return rewriter.notifyMatchFailure( op, "only constant dim is currently supported"); auto rank = selfTy.getRank(); if (rank == 0) return rewriter.notifyMatchFailure( op, "the rank of tensor must be greater than 0"); dim = toPositiveDim(dim, rank); if (selfTy.getShape()[dim] != 1) { if (selfTy.getShape()[dim] == ShapedType::kDynamicSize) return rewriter.notifyMatchFailure( op, "the size of the dimension being squeezed is can't be unknown"); rewriter.replaceOp(op, adaptor.self()); return success(); } SmallVector dims(rank); std::iota(dims.begin(), dims.end(), 0); dims.erase(dims.begin() + dim); auto newDimSizesInfo = mhlo::getDimSizesOfTensor(rewriter, op, self, dims); if (failed(newDimSizesInfo)) return rewriter.notifyMatchFailure( op, "failed to get dimension sizes of the input"); auto newDimSizes = *newDimSizesInfo; auto mhloShape = rewriter.create(op.getLoc(), newDimSizes); rewriter.replaceOpWithNewOp( op, getTypeConverter()->convertType(op.getType()), self, mhloShape); return success(); } template <> LogicalResult ConvertAtenOp::matchAndRewrite( AtenUnsqueezeOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { auto selfType = adaptor.self().getType().dyn_cast(); if (!selfType) { return op.emitError("only tensor types are currently supported"); } int64_t dim; if (!matchPattern(op.dim(), m_TorchConstantInt(&dim))) return op->emitError("dim must be a Scalar constant"); auto unsqzTensorInfo = mhlo::unsqueezeTensor(rewriter, op, adaptor.self(), {dim}); if (failed(unsqzTensorInfo)) return rewriter.notifyMatchFailure(op, "failed to create unsqueezed tensor"); rewriter.replaceOp(op, *unsqzTensorInfo); return success(); } } // namespace void mlir::torch::torch_to_mhlo::populateViewLikeOpPatternsAndLegality( TypeConverter &typeConverter, RewritePatternSet &patterns, ConversionTarget &target) { MLIRContext *context = patterns.getContext(); #define INSERT_ATENOP_PATTERN(AtenOp) \ target.addIllegalOp(); \ patterns.add>(typeConverter, context); INSERT_ATENOP_PATTERN(AtenSliceTensorOp); INSERT_ATENOP_PATTERN(AtenSqueezeOp); INSERT_ATENOP_PATTERN(AtenSqueezeDimOp); INSERT_ATENOP_PATTERN(AtenUnsqueezeOp); #undef INSERT_ATENOP_PATTERN #define INSERT_VIEW_OP_PATTERN(AtenOp) \ target.addIllegalOp(); \ patterns.add>(typeConverter, context); INSERT_VIEW_OP_PATTERN(AtenViewOp); INSERT_VIEW_OP_PATTERN(AtenReshapeOp); #undef INSERT_VIEW_OP_PATTERN }