From 0a5ff68d9d57c9c3948b6d60c1edb32da9fe3670 Mon Sep 17 00:00:00 2001
From: Xinyu Yang
Date: Mon, 29 Apr 2024 17:40:30 +0800
Subject: [PATCH] [stablehlo] Support PrimsCollapseOp and PrimsSplitDimOp in
 stablehlo (#3230)

---
 .../TorchToStablehlo/StablehloLegalizeUtils.h |  11 ++
 .../StablehloLegalizeUtils.cpp                | 131 ++++++++++++++++++
 lib/Conversion/TorchToStablehlo/ViewLike.cpp  |  63 +++++----
 projects/pt1/e2e_testing/xfail_sets.py        |   8 +-
 4 files changed, 182 insertions(+), 31 deletions(-)

diff --git a/include/torch-mlir/Conversion/TorchToStablehlo/StablehloLegalizeUtils.h b/include/torch-mlir/Conversion/TorchToStablehlo/StablehloLegalizeUtils.h
index 6e14b324b..734ba81ea 100644
--- a/include/torch-mlir/Conversion/TorchToStablehlo/StablehloLegalizeUtils.h
+++ b/include/torch-mlir/Conversion/TorchToStablehlo/StablehloLegalizeUtils.h
@@ -69,6 +69,17 @@ FailureOr<Value> unsqueezeTensor(PatternRewriter &rewriter, Operation *op,
                                  Value tensor, ArrayRef<int64_t> inputUnsqzDims,
                                  size_t dimSizeIndexBits);
 
+// Get a tensor that collapses the specified dimensions of the input tensor
+FailureOr<Value> collapseTensor(PatternRewriter &rewriter, Operation *op,
+                                Value tensor, int64_t collapseStartDim,
+                                int64_t collapseEndDim,
+                                size_t dimSizeIndexBits);
+
+// Get a tensor that splits the specified dimension of the input tensor
+FailureOr<Value> splitTensor(PatternRewriter &rewriter, Operation *op,
+                             Value tensor, int64_t splitDim,
+                             int64_t outerLength, size_t dimSizeIndexBits);
+
 Value getConstantOfShape(PatternRewriter &rewriter, Location loc,
                          const APFloat &constant, Value shape,
                          TensorType outType);
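
Note: both helpers return FailureOr<Value> rather than Value, so callers propagate a match
failure instead of asserting when shape information is unavailable. A minimal sketch of the
intended call pattern (the surrounding pattern boilerplate is elided; `input`, `start`, `end`,
`dim`, and `outerLength` are illustrative names, while `op`, `rewriter`, and `options` are
assumed to come from the usual ConvertAtenOp context, as in ViewLike.cpp below):

    // Collapse dims [start, end] of `input` into a single dimension.
    auto collapsed = hlo::collapseTensor(rewriter, op, input, start, end,
                                         options.dimSizeIndexBits);
    if (failed(collapsed))
      return rewriter.notifyMatchFailure(op, "failed to create collapsed tensor");

    // Split `dim` of the result back into (outerLength, inner) dimensions.
    auto split = hlo::splitTensor(rewriter, op, *collapsed, dim, outerLength,
                                  options.dimSizeIndexBits);
    if (failed(split))
      return rewriter.notifyMatchFailure(op, "failed to create split tensor");

    rewriter.replaceOp(op, *split);
    return success();
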
diff --git a/lib/Conversion/TorchToStablehlo/StablehloLegalizeUtils.cpp b/lib/Conversion/TorchToStablehlo/StablehloLegalizeUtils.cpp
index 40ec715cd..c4d629d4f 100644
--- a/lib/Conversion/TorchToStablehlo/StablehloLegalizeUtils.cpp
+++ b/lib/Conversion/TorchToStablehlo/StablehloLegalizeUtils.cpp
@@ -9,6 +9,7 @@
 #include "torch-mlir/Conversion/TorchToStablehlo/StablehloLegalizeUtils.h"
 
 #include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Shape/IR/Shape.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "stablehlo/dialect/StablehloOps.h"
 #include "torch-mlir/Conversion/TorchToStablehlo/TorchToStablehlo.h"
@@ -306,6 +307,136 @@ FailureOr<Value> unsqueezeTensor(PatternRewriter &rewriter, Operation *op,
       .getResult();
 }
 
+FailureOr<Value> collapseTensor(PatternRewriter &rewriter, Operation *op,
+                                Value tensor, int64_t collapseStartDim,
+                                int64_t collapseEndDim,
+                                size_t dimSizeIndexBits) {
+
+  auto dimSizesInfo =
+      getDimSizesOfTensor(rewriter, op, tensor, dimSizeIndexBits);
+
+  if (failed(dimSizesInfo))
+    return rewriter.notifyMatchFailure(
+        op, "failed to get dimension sizes of the input");
+
+  auto dimSizes = *dimSizesInfo;
+  int64_t rank = dimSizes.size();
+
+  collapseStartDim = toPositiveDim(collapseStartDim, rank);
+  collapseEndDim = toPositiveDim(collapseEndDim, rank);
+
+  int64_t newRank = rank - (collapseEndDim - collapseStartDim + 1);
+
+  auto loc = op->getLoc();
+  auto rankTy = dyn_cast<RankedTensorType>(tensor.getType());
+  auto oldShape = rankTy.getShape();
+  Type intType = rewriter.getIntegerType(dimSizeIndexBits);
+
+  std::vector<Value> newDimSizes;
+  std::vector<int64_t> newShape;
+  newDimSizes.reserve(newRank);
+  newShape.reserve(newRank);
+
+  Value collapseDimSize = rewriter.create<arith::ConstantOp>(
+      loc, rewriter.getIntegerAttr(intType, 1));
+  int64_t collapseShape = 1;
+
+  for (int64_t k = collapseStartDim; k <= collapseEndDim; ++k) {
+    if (k < 0 || k >= rank) {
+      return rewriter.notifyMatchFailure(
+          op, "collapse dimensions must be within the rank of the tensor");
+    }
+    if (collapseShape == ShapedType::kDynamic ||
+        oldShape[k] == ShapedType::kDynamic) {
+      collapseShape = ShapedType::kDynamic;
+    } else {
+      collapseShape *= oldShape[k];
+    }
+    collapseDimSize =
+        rewriter.create<arith::MulIOp>(loc, collapseDimSize, dimSizes[k]);
+  }
+
+  for (int64_t k = 0; k < collapseStartDim; ++k) {
+    newDimSizes.push_back(dimSizes[k]);
+    newShape.push_back(oldShape[k]);
+  }
+  newDimSizes.push_back(collapseDimSize);
+  newShape.push_back(collapseShape);
+  for (int64_t k = collapseEndDim + 1; k < rank; ++k) {
+    newDimSizes.push_back(dimSizes[k]);
+    newShape.push_back(oldShape[k]);
+  }
+
+  auto outTy = RankedTensorType::get(newShape, rankTy.getElementType());
+  auto shape = rewriter.create<tensor::FromElementsOp>(loc, newDimSizes);
+  return rewriter
+      .create<stablehlo::DynamicReshapeOp>(loc, outTy, tensor, shape)
+      .getResult();
+}
+
+// TODO: support splitDim & outerLength to be Value
+FailureOr<Value> splitTensor(PatternRewriter &rewriter, Operation *op,
+                             Value tensor, int64_t splitDim,
+                             int64_t outerLength, size_t dimSizeIndexBits) {
+  auto dimSizesInfo =
+      getDimSizesOfTensor(rewriter, op, tensor, dimSizeIndexBits);
+
+  if (failed(dimSizesInfo))
+    return rewriter.notifyMatchFailure(
+        op, "failed to get dimension sizes of the input");
+
+  auto dimSizes = *dimSizesInfo;
+  int64_t rank = dimSizes.size();
+  splitDim = toPositiveDim(splitDim, rank);
+
+  auto loc = op->getLoc();
+  auto rankTy = dyn_cast<RankedTensorType>(tensor.getType());
+  auto oldShape = rankTy.getShape();
+  Type intType = rewriter.getIntegerType(dimSizeIndexBits);
+
+  if (splitDim < 0 || splitDim >= rank) {
+    return rewriter.notifyMatchFailure(
+        op, "split dimension must be within the rank of the tensor");
+  }
+
+  int64_t newRank = rank + 1;
+  auto outerLengthValue = rewriter.create<arith::ConstantOp>(
+      loc, rewriter.getIntegerAttr(intType, outerLength));
+
+  auto innerLengthValue = rewriter.create<arith::DivSIOp>(
+      loc, dimSizes[splitDim], outerLengthValue);
+
+  int64_t originShape = oldShape[splitDim];
+  int64_t outerShape = outerLength;
+  int64_t innerShape = originShape == ShapedType::kDynamic
+                           ? ShapedType::kDynamic
+                           : originShape / outerLength;
+
+  std::vector<Value> newDimSizes;
+  std::vector<int64_t> newShape;
+
+  newDimSizes.reserve(newRank);
+  newShape.reserve(newRank);
+
+  for (int64_t k = 0; k < splitDim; ++k) {
+    newDimSizes.push_back(dimSizes[k]);
+    newShape.push_back(oldShape[k]);
+  }
+  newDimSizes.push_back(outerLengthValue);
+  newShape.push_back(outerShape);
+  newDimSizes.push_back(innerLengthValue);
+  newShape.push_back(innerShape);
+
+  for (int64_t k = splitDim + 1; k < rank; ++k) {
+    newDimSizes.push_back(dimSizes[k]);
+    newShape.push_back(oldShape[k]);
+  }
+
+  auto outTy = RankedTensorType::get(newShape, rankTy.getElementType());
+  auto shape = rewriter.create<tensor::FromElementsOp>(loc, newDimSizes);
+  return rewriter
+      .create<stablehlo::DynamicReshapeOp>(loc, outTy, tensor, shape)
+      .getResult();
+}
+
 Value getConstantOfShape(PatternRewriter &rewriter, Location loc,
                          const APFloat &constant, Value shape,
                          TensorType outType) {
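
Note: the only subtle part of these helpers is the static-shape bookkeeping. A collapsed
dimension is the product of the collapsed extents unless any of them is ShapedType::kDynamic
(in which case the result is dynamic), and a split dimension keeps a static inner extent only
when the original extent is static. A standalone sketch of that arithmetic outside of MLIR
(kDynamic stands in for ShapedType::kDynamic; this mirrors the logic above, it is not the
patch code):

    #include <cstdint>
    #include <iostream>
    #include <limits>
    #include <vector>

    constexpr int64_t kDynamic = std::numeric_limits<int64_t>::min();

    // Shape of collapseTensor(tensor, start, end): dims [start, end] fold into one.
    std::vector<int64_t> collapseShape(const std::vector<int64_t> &shape,
                                       int64_t start, int64_t end) {
      std::vector<int64_t> out(shape.begin(), shape.begin() + start);
      int64_t collapsed = 1;
      for (int64_t k = start; k <= end; ++k)
        collapsed = (collapsed == kDynamic || shape[k] == kDynamic)
                        ? kDynamic
                        : collapsed * shape[k];
      out.push_back(collapsed);
      out.insert(out.end(), shape.begin() + end + 1, shape.end());
      return out;
    }

    // Shape of splitTensor(tensor, dim, outerLength): dim becomes (outer, inner).
    std::vector<int64_t> splitShape(const std::vector<int64_t> &shape,
                                    int64_t dim, int64_t outerLength) {
      std::vector<int64_t> out(shape.begin(), shape.begin() + dim);
      out.push_back(outerLength);
      out.push_back(shape[dim] == kDynamic ? kDynamic : shape[dim] / outerLength);
      out.insert(out.end(), shape.begin() + dim + 1, shape.end());
      return out;
    }

    int main() {
      // Collapse dims 1..2 of [2, 3, 4]: prints "2 12".
      for (int64_t d : collapseShape({2, 3, 4}, 1, 2))
        std::cout << d << ' ';
      std::cout << '\n';
      // Split dim 1 of [2, ?] with outerLength 3: prints "2 3" then the
      // kDynamic sentinel (INT64_MIN) for the dynamic inner dim.
      for (int64_t d : splitShape({2, kDynamic}, 1, 3))
        std::cout << d << ' ';
      std::cout << '\n';
    }
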
diff --git a/lib/Conversion/TorchToStablehlo/ViewLike.cpp b/lib/Conversion/TorchToStablehlo/ViewLike.cpp
index e43105ea1..04952d843 100644
--- a/lib/Conversion/TorchToStablehlo/ViewLike.cpp
+++ b/lib/Conversion/TorchToStablehlo/ViewLike.cpp
@@ -414,34 +414,44 @@ LogicalResult ConvertAtenOp<PrimsCollapseOp>::matchAndRewrite(
     return rewriter.notifyMatchFailure(
         op, "only constant end is currently supported");
 
-  start = toPositiveDim(start, rank);
-  end = toPositiveDim(end, rank);
-  SmallVector<int64_t> dims;
-  dims.reserve(rank);
-  for (int r = 0; r < start; ++r)
-    dims.push_back(r);
-  int64_t collapsedDimSize = 1;
-  for (int r = start; r <= end; ++r) {
-    if (selfType.getShape()[r] == ShapedType::kDynamic)
-      return rewriter.notifyMatchFailure(
-          op, "the size of the dimension being collapsed is can't be unknown");
-    collapsedDimSize *= selfType.getShape()[r];
-  }
-  dims.push_back(collapsedDimSize);
-  for (int r = end + 1; r < rank; ++r)
-    dims.push_back(r);
+  auto collapseTensorInfo = hlo::collapseTensor(
+      rewriter, op, adaptor.getA(), start, end, options.dimSizeIndexBits);
+  if (failed(collapseTensorInfo))
+    return rewriter.notifyMatchFailure(op, "failed to create collapsed tensor");
 
-  auto newDimSizesInfo = hlo::getDimSizesOfTensor(
-      rewriter, op, adaptor.getA(), dims, options.dimSizeIndexBits);
-  if (failed(newDimSizesInfo))
+  rewriter.replaceOp(op, *collapseTensorInfo);
+  return success();
+}
+
+template <>
+LogicalResult ConvertAtenOp<PrimsSplitDimOp>::matchAndRewrite(
+    PrimsSplitDimOp op, OpAdaptor adaptor,
+    ConversionPatternRewriter &rewriter) const {
+  auto selfType = adaptor.getA().getType().dyn_cast<TensorType>();
+  if (!selfType) {
+    return op.emitError("only tensor types are currently supported");
+  }
+
+  auto rank = selfType.getRank();
+  if (rank == 0)
     return rewriter.notifyMatchFailure(
-        op, "failed to get dimension sizes of the input");
-  auto newDimSizes = *newDimSizesInfo;
-  auto stablehloShape =
-      rewriter.create<tensor::FromElementsOp>(op.getLoc(), newDimSizes);
-  rewriter.replaceOpWithNewOp<stablehlo::DynamicReshapeOp>(
-      op, getTypeConverter()->convertType(op.getType()), adaptor.getA(),
-      stablehloShape);
+        op, "the rank of tensor must be greater than 0");
+
+  int64_t dim, outerLength;
+  if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim)))
+    return rewriter.notifyMatchFailure(
+        op, "only constant dim is currently supported");
+  if (!matchPattern(op.getOuterLength(), m_TorchConstantInt(&outerLength)))
+    return rewriter.notifyMatchFailure(
+        op, "only constant outerLength is currently supported");
+
+  auto splitTensorInfo = hlo::splitTensor(
+      rewriter, op, adaptor.getA(), dim, outerLength, options.dimSizeIndexBits);
+
+  if (failed(splitTensorInfo))
+    return rewriter.notifyMatchFailure(op, "failed to create split tensor");
+
+  rewriter.replaceOp(op, *splitTensorInfo);
   return success();
 }
 
@@ -458,6 +468,7 @@ void mlir::torch::torch_to_stablehlo::populateViewLikeOpPatternsAndLegality(
   INSERT_ATENOP_PATTERN(AtenSqueezeDimOp);
   INSERT_ATENOP_PATTERN(AtenUnsqueezeOp);
   INSERT_ATENOP_PATTERN(PrimsCollapseOp);
+  INSERT_ATENOP_PATTERN(PrimsSplitDimOp);
 #undef INSERT_ATENOP_PATTERN
 
 #define INSERT_VIEW_OP_PATTERN(AtenOp)                                         \
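
Note: both patterns bottom out in stablehlo::DynamicReshapeOp because, on a contiguous
row-major buffer, prims.collapse and prims.split_dim only reinterpret indices without moving
data, so a reshape fully captures their semantics. A small standalone check of that claim for
collapsing dims 1..2 of a [2, 3, 4] tensor (illustrative only, not part of the patch):

    #include <cstdint>
    #include <iostream>

    int main() {
      const int64_t D0 = 2, D1 = 3, D2 = 4;
      for (int64_t i = 0; i < D0; ++i)
        for (int64_t j = 0; j < D1; ++j)
          for (int64_t k = 0; k < D2; ++k) {
            // Linear offset of (i, j, k) in the [2, 3, 4] layout ...
            int64_t idx3 = (i * D1 + j) * D2 + k;
            // ... equals the offset of (i, j*D2 + k) in the collapsed [2, 12] layout.
            int64_t idx2 = i * (D1 * D2) + (j * D2 + k);
            if (idx3 != idx2) {
              std::cout << "mismatch\n";
              return 1;
            }
          }
      std::cout << "collapse(1, 2) on [2, 3, 4] is a pure reshape\n";
      return 0;
    }
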
diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py
index e45839617..10c24b657 100644
--- a/projects/pt1/e2e_testing/xfail_sets.py
+++ b/projects/pt1/e2e_testing/xfail_sets.py
@@ -678,11 +678,6 @@ FX_IMPORTER_STABLEHLO_XFAIL_SET = {
     "NumToTensorIntModule_basic",
     "NumelModule_basic",
     "NumelZeroRankModule_basic",
-    "PixelShuffleModuleFullDynamic_basic",
-    "PixelShuffleModuleSpatiallyDynamic_basic",
-    "PixelShuffleModuleSpatiallyStatic_basic",
-    "PixelShuffleModuleStaticRank3Int64_basic",
-    "PixelShuffleModuleStaticRank4Float32_basic",
     "PowIntFloatModule_basic",
     "PrimMaxIntModule_basic",
     "PrimMinIntDynamicModule_basic",
@@ -1157,6 +1152,8 @@ STABLEHLO_PASS_SET = {
     "Permute0RankModule_basic",
     "PermuteModule_basic",
     "PermuteNegativeIndexModule_basic",
+    "PixelShuffleModuleStaticRank3Int64_basic",
+    "PixelShuffleModuleStaticRank4Float32_basic",
     "PowIntFloatModule_basic",
     "PrimListUnpackNumMismatchModule_basic",
     "PrimMaxIntModule_basic",
@@ -1240,6 +1237,7 @@ STABLEHLO_PASS_SET = {
     "SliceWholeTensorModule_basic",
     "SortIntListReverse_basic",
     "SortIntList_basic",
+    "SplitDimStaticModule_basic",
     "SplitTensorGetItem_Module_basic",
     "SplitTensorLastSmallerModule_basic",
     "SplitTensorListUnpackModule_basic",