[stablehlo] Support PrimsCollapseOp and PrimsSplitDimOp in stablehlo (#3230)

pull/3261/head
Xinyu Yang 2024-04-29 17:40:30 +08:00 committed by GitHub
parent b1e2241479
commit 0a5ff68d9d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 182 additions and 31 deletions

View File

@ -69,6 +69,17 @@ FailureOr<Value> unsqueezeTensor(PatternRewriter &rewriter, Operation *op,
Value tensor, ArrayRef<int64_t> inputUnsqzDims, Value tensor, ArrayRef<int64_t> inputUnsqzDims,
size_t dimSizeIndexBits); size_t dimSizeIndexBits);
// Get a tensor that collapse the specified dimensions of the input tensor
FailureOr<Value> collapseTensor(PatternRewriter &rewriter, Operation *op,
Value tensor, int64_t collapseStartDim,
int64_t collapseEndDim,
size_t dimSizeIndexBits);
// Get a tensor that splits the specified dimensions of the input tensor
FailureOr<Value> splitTensor(PatternRewriter &rewriter, Operation *op,
Value tensor, int64_t splitDim,
int64_t outerLength, size_t dimSizeIndexBits);
Value getConstantOfShape(PatternRewriter &rewriter, Location loc, Value getConstantOfShape(PatternRewriter &rewriter, Location loc,
const APFloat &constant, Value shape, const APFloat &constant, Value shape,
TensorType outType); TensorType outType);

View File

@ -9,6 +9,7 @@
#include "torch-mlir/Conversion/TorchToStablehlo/StablehloLegalizeUtils.h" #include "torch-mlir/Conversion/TorchToStablehlo/StablehloLegalizeUtils.h"
#include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "stablehlo/dialect/StablehloOps.h" #include "stablehlo/dialect/StablehloOps.h"
#include "torch-mlir/Conversion/TorchToStablehlo/TorchToStablehlo.h" #include "torch-mlir/Conversion/TorchToStablehlo/TorchToStablehlo.h"
@ -306,6 +307,136 @@ FailureOr<Value> unsqueezeTensor(PatternRewriter &rewriter, Operation *op,
.getResult(); .getResult();
} }
FailureOr<Value> collapseTensor(PatternRewriter &rewriter, Operation *op,
Value tensor, int64_t collapseStartDim,
int64_t collapseEndDim,
size_t dimSizeIndexBits) {
auto dimSizesInfo =
getDimSizesOfTensor(rewriter, op, tensor, dimSizeIndexBits);
if (failed(dimSizesInfo))
return rewriter.notifyMatchFailure(
op, "failed to get dimension sizes of the input");
auto dimSizes = *dimSizesInfo;
int64_t rank = dimSizes.size();
collapseStartDim = toPositiveDim(collapseStartDim, rank);
collapseEndDim = toPositiveDim(collapseEndDim, rank);
int64_t newRank = rank - (collapseEndDim - collapseStartDim + 1);
auto loc = op->getLoc();
auto rankTy = dyn_cast<RankedTensorType>(tensor.getType());
auto oldShape = rankTy.getShape();
Type intType = rewriter.getIntegerType(dimSizeIndexBits);
std::vector<Value> newDimSizes;
std::vector<int64_t> newShape;
newDimSizes.reserve(newRank);
newShape.reserve(newRank);
Value collapseDimSize = rewriter.create<arith::ConstantOp>(
loc, rewriter.getIntegerAttr(intType, 1));
int64_t collapseShape = 1;
for (int64_t k = collapseStartDim; k <= collapseEndDim; ++k) {
if (k < 0 || k >= rank) {
return rewriter.notifyMatchFailure(
op, "collapse dimensions must be within the rank of the tensor");
}
if (collapseShape == ShapedType::kDynamic ||
oldShape[k] == ShapedType::kDynamic) {
collapseShape = ShapedType::kDynamic;
} else {
collapseShape *= oldShape[k];
}
collapseDimSize =
rewriter.create<arith::MulIOp>(loc, collapseDimSize, dimSizes[k]);
}
for (int64_t k = 0; k < collapseStartDim; ++k) {
newDimSizes.push_back(dimSizes[k]);
newShape.push_back(oldShape[k]);
}
newDimSizes.push_back(collapseDimSize);
newShape.push_back(collapseShape);
for (int64_t k = collapseEndDim + 1; k < rank; ++k) {
newDimSizes.push_back(dimSizes[k]);
newShape.push_back(oldShape[k]);
}
auto outTy = RankedTensorType::get(newShape, rankTy.getElementType());
auto shape = rewriter.create<tensor::FromElementsOp>(loc, newDimSizes);
return rewriter.create<stablehlo::DynamicReshapeOp>(loc, outTy, tensor, shape)
.getResult();
}
// TODO: support splitDim & outerLength to be Value
FailureOr<Value> splitTensor(PatternRewriter &rewriter, Operation *op,
Value tensor, int64_t splitDim,
int64_t outerLength, size_t dimSizeIndexBits) {
auto dimSizesInfo =
getDimSizesOfTensor(rewriter, op, tensor, dimSizeIndexBits);
if (failed(dimSizesInfo))
return rewriter.notifyMatchFailure(
op, "failed to get dimension sizes of the input");
auto dimSizes = *dimSizesInfo;
int64_t rank = dimSizes.size();
splitDim = toPositiveDim(splitDim, rank);
auto loc = op->getLoc();
auto rankTy = dyn_cast<RankedTensorType>(tensor.getType());
auto oldShape = rankTy.getShape();
Type intType = rewriter.getIntegerType(dimSizeIndexBits);
if (splitDim < 0 || splitDim >= rank) {
return rewriter.notifyMatchFailure(
op, "split dimensions must be within the rank of the tensor");
}
int64_t newRank = rank + 1;
auto outerLengthValue = rewriter.create<arith::ConstantOp>(
loc, rewriter.getIntegerAttr(intType, outerLength));
auto innerLengthValue = rewriter.create<arith::DivSIOp>(
loc, dimSizes[splitDim], outerLengthValue);
int64_t originShape = oldShape[splitDim];
int64_t outerShape = outerLength;
int64_t innerShape = originShape == ShapedType::kDynamic
? ShapedType::kDynamic
: originShape / outerLength;
std::vector<Value> newDimSizes;
std::vector<int64_t> newShape;
newDimSizes.reserve(newRank);
newShape.reserve(newRank);
for (int64_t k = 0; k < splitDim; ++k) {
newDimSizes.push_back(dimSizes[k]);
newShape.push_back(oldShape[k]);
}
newDimSizes.push_back(outerLengthValue);
newShape.push_back(outerShape);
newDimSizes.push_back(innerLengthValue);
newShape.push_back(innerShape);
for (int64_t k = splitDim + 1; k < rank; ++k) {
newDimSizes.push_back(dimSizes[k]);
newShape.push_back(oldShape[k]);
}
auto outTy = RankedTensorType::get(newShape, rankTy.getElementType());
auto shape = rewriter.create<tensor::FromElementsOp>(loc, newDimSizes);
return rewriter.create<stablehlo::DynamicReshapeOp>(loc, outTy, tensor, shape)
.getResult();
}
Value getConstantOfShape(PatternRewriter &rewriter, Location loc, Value getConstantOfShape(PatternRewriter &rewriter, Location loc,
const APFloat &constant, Value shape, const APFloat &constant, Value shape,
TensorType outType) { TensorType outType) {

View File

@ -414,34 +414,44 @@ LogicalResult ConvertAtenOp<PrimsCollapseOp>::matchAndRewrite(
return rewriter.notifyMatchFailure( return rewriter.notifyMatchFailure(
op, "only constant end is currently supported"); op, "only constant end is currently supported");
start = toPositiveDim(start, rank); auto collapseTensorInfo = hlo::collapseTensor(
end = toPositiveDim(end, rank); rewriter, op, adaptor.getA(), start, end, options.dimSizeIndexBits);
SmallVector<int64_t, 4> dims; if (failed(collapseTensorInfo))
dims.reserve(rank); return rewriter.notifyMatchFailure(op, "failed to create collapsed tensor");
for (int r = 0; r < start; ++r)
dims.push_back(r);
int64_t collapsedDimSize = 1;
for (int r = start; r <= end; ++r) {
if (selfType.getShape()[r] == ShapedType::kDynamic)
return rewriter.notifyMatchFailure(
op, "the size of the dimension being collapsed is can't be unknown");
collapsedDimSize *= selfType.getShape()[r];
}
dims.push_back(collapsedDimSize);
for (int r = end + 1; r < rank; ++r)
dims.push_back(r);
auto newDimSizesInfo = hlo::getDimSizesOfTensor( rewriter.replaceOp(op, *collapseTensorInfo);
rewriter, op, adaptor.getA(), dims, options.dimSizeIndexBits); return success();
if (failed(newDimSizesInfo)) }
template <>
LogicalResult ConvertAtenOp<PrimsSplitDimOp>::matchAndRewrite(
PrimsSplitDimOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
auto selfType = adaptor.getA().getType().dyn_cast<TensorType>();
if (!selfType) {
return op.emitError("only tensor types are currently supported");
}
auto rank = selfType.getRank();
if (rank == 0)
return rewriter.notifyMatchFailure( return rewriter.notifyMatchFailure(
op, "failed to get dimension sizes of the input"); op, "the rank of tensor must be greater than 0");
auto newDimSizes = *newDimSizesInfo;
auto stablehloShape = int64_t dim, outerLength;
rewriter.create<tensor::FromElementsOp>(op.getLoc(), newDimSizes); if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim)))
rewriter.replaceOpWithNewOp<stablehlo::DynamicReshapeOp>( return rewriter.notifyMatchFailure(
op, getTypeConverter()->convertType(op.getType()), adaptor.getA(), op, "only constant dim is currently supported");
stablehloShape); if (!matchPattern(op.getOuterLength(), m_TorchConstantInt(&outerLength)))
return rewriter.notifyMatchFailure(
op, "only constant outerLength is currently supported");
auto splitTensorInfo = hlo::splitTensor(
rewriter, op, adaptor.getA(), dim, outerLength, options.dimSizeIndexBits);
if (failed(splitTensorInfo))
return rewriter.notifyMatchFailure(op, "failed to create split tensor");
rewriter.replaceOp(op, *splitTensorInfo);
return success(); return success();
} }
@ -458,6 +468,7 @@ void mlir::torch::torch_to_stablehlo::populateViewLikeOpPatternsAndLegality(
INSERT_ATENOP_PATTERN(AtenSqueezeDimOp); INSERT_ATENOP_PATTERN(AtenSqueezeDimOp);
INSERT_ATENOP_PATTERN(AtenUnsqueezeOp); INSERT_ATENOP_PATTERN(AtenUnsqueezeOp);
INSERT_ATENOP_PATTERN(PrimsCollapseOp); INSERT_ATENOP_PATTERN(PrimsCollapseOp);
INSERT_ATENOP_PATTERN(PrimsSplitDimOp);
#undef INSERT_ATENOP_PATTERN #undef INSERT_ATENOP_PATTERN
#define INSERT_VIEW_OP_PATTERN(AtenOp) \ #define INSERT_VIEW_OP_PATTERN(AtenOp) \

View File

@ -678,11 +678,6 @@ FX_IMPORTER_STABLEHLO_XFAIL_SET = {
"NumToTensorIntModule_basic", "NumToTensorIntModule_basic",
"NumelModule_basic", "NumelModule_basic",
"NumelZeroRankModule_basic", "NumelZeroRankModule_basic",
"PixelShuffleModuleFullDynamic_basic",
"PixelShuffleModuleSpatiallyDynamic_basic",
"PixelShuffleModuleSpatiallyStatic_basic",
"PixelShuffleModuleStaticRank3Int64_basic",
"PixelShuffleModuleStaticRank4Float32_basic",
"PowIntFloatModule_basic", "PowIntFloatModule_basic",
"PrimMaxIntModule_basic", "PrimMaxIntModule_basic",
"PrimMinIntDynamicModule_basic", "PrimMinIntDynamicModule_basic",
@ -1157,6 +1152,8 @@ STABLEHLO_PASS_SET = {
"Permute0RankModule_basic", "Permute0RankModule_basic",
"PermuteModule_basic", "PermuteModule_basic",
"PermuteNegativeIndexModule_basic", "PermuteNegativeIndexModule_basic",
"PixelShuffleModuleStaticRank3Int64_basic",
"PixelShuffleModuleStaticRank4Float32_basic",
"PowIntFloatModule_basic", "PowIntFloatModule_basic",
"PrimListUnpackNumMismatchModule_basic", "PrimListUnpackNumMismatchModule_basic",
"PrimMaxIntModule_basic", "PrimMaxIntModule_basic",
@ -1240,6 +1237,7 @@ STABLEHLO_PASS_SET = {
"SliceWholeTensorModule_basic", "SliceWholeTensorModule_basic",
"SortIntListReverse_basic", "SortIntListReverse_basic",
"SortIntList_basic", "SortIntList_basic",
"SplitDimStaticModule_basic",
"SplitTensorGetItem_Module_basic", "SplitTensorGetItem_Module_basic",
"SplitTensorLastSmallerModule_basic", "SplitTensorLastSmallerModule_basic",
"SplitTensorListUnpackModule_basic", "SplitTensorListUnpackModule_basic",