mirror of https://github.com/llvm/torch-mlir
[stablehlo] Support PrimsCollapseOp and PrimsSplitDimOp in stablehlo (#3230)
parent
b1e2241479
commit
0a5ff68d9d
|
@ -69,6 +69,17 @@ FailureOr<Value> unsqueezeTensor(PatternRewriter &rewriter, Operation *op,
|
|||
Value tensor, ArrayRef<int64_t> inputUnsqzDims,
|
||||
size_t dimSizeIndexBits);
|
||||
|
||||
// Get a tensor that collapse the specified dimensions of the input tensor
|
||||
FailureOr<Value> collapseTensor(PatternRewriter &rewriter, Operation *op,
|
||||
Value tensor, int64_t collapseStartDim,
|
||||
int64_t collapseEndDim,
|
||||
size_t dimSizeIndexBits);
|
||||
|
||||
// Get a tensor that splits the specified dimensions of the input tensor
|
||||
FailureOr<Value> splitTensor(PatternRewriter &rewriter, Operation *op,
|
||||
Value tensor, int64_t splitDim,
|
||||
int64_t outerLength, size_t dimSizeIndexBits);
|
||||
|
||||
Value getConstantOfShape(PatternRewriter &rewriter, Location loc,
|
||||
const APFloat &constant, Value shape,
|
||||
TensorType outType);
|
||||
|
|
|
@ -9,6 +9,7 @@
|
|||
#include "torch-mlir/Conversion/TorchToStablehlo/StablehloLegalizeUtils.h"
|
||||
|
||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||
#include "mlir/Dialect/Shape/IR/Shape.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "stablehlo/dialect/StablehloOps.h"
|
||||
#include "torch-mlir/Conversion/TorchToStablehlo/TorchToStablehlo.h"
|
||||
|
@ -306,6 +307,136 @@ FailureOr<Value> unsqueezeTensor(PatternRewriter &rewriter, Operation *op,
|
|||
.getResult();
|
||||
}
|
||||
|
||||
FailureOr<Value> collapseTensor(PatternRewriter &rewriter, Operation *op,
|
||||
Value tensor, int64_t collapseStartDim,
|
||||
int64_t collapseEndDim,
|
||||
size_t dimSizeIndexBits) {
|
||||
|
||||
auto dimSizesInfo =
|
||||
getDimSizesOfTensor(rewriter, op, tensor, dimSizeIndexBits);
|
||||
|
||||
if (failed(dimSizesInfo))
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "failed to get dimension sizes of the input");
|
||||
|
||||
auto dimSizes = *dimSizesInfo;
|
||||
int64_t rank = dimSizes.size();
|
||||
|
||||
collapseStartDim = toPositiveDim(collapseStartDim, rank);
|
||||
collapseEndDim = toPositiveDim(collapseEndDim, rank);
|
||||
|
||||
int64_t newRank = rank - (collapseEndDim - collapseStartDim + 1);
|
||||
|
||||
auto loc = op->getLoc();
|
||||
auto rankTy = dyn_cast<RankedTensorType>(tensor.getType());
|
||||
auto oldShape = rankTy.getShape();
|
||||
Type intType = rewriter.getIntegerType(dimSizeIndexBits);
|
||||
|
||||
std::vector<Value> newDimSizes;
|
||||
std::vector<int64_t> newShape;
|
||||
newDimSizes.reserve(newRank);
|
||||
newShape.reserve(newRank);
|
||||
|
||||
Value collapseDimSize = rewriter.create<arith::ConstantOp>(
|
||||
loc, rewriter.getIntegerAttr(intType, 1));
|
||||
int64_t collapseShape = 1;
|
||||
|
||||
for (int64_t k = collapseStartDim; k <= collapseEndDim; ++k) {
|
||||
if (k < 0 || k >= rank) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "collapse dimensions must be within the rank of the tensor");
|
||||
}
|
||||
if (collapseShape == ShapedType::kDynamic ||
|
||||
oldShape[k] == ShapedType::kDynamic) {
|
||||
collapseShape = ShapedType::kDynamic;
|
||||
} else {
|
||||
collapseShape *= oldShape[k];
|
||||
}
|
||||
collapseDimSize =
|
||||
rewriter.create<arith::MulIOp>(loc, collapseDimSize, dimSizes[k]);
|
||||
}
|
||||
|
||||
for (int64_t k = 0; k < collapseStartDim; ++k) {
|
||||
newDimSizes.push_back(dimSizes[k]);
|
||||
newShape.push_back(oldShape[k]);
|
||||
}
|
||||
newDimSizes.push_back(collapseDimSize);
|
||||
newShape.push_back(collapseShape);
|
||||
for (int64_t k = collapseEndDim + 1; k < rank; ++k) {
|
||||
newDimSizes.push_back(dimSizes[k]);
|
||||
newShape.push_back(oldShape[k]);
|
||||
}
|
||||
|
||||
auto outTy = RankedTensorType::get(newShape, rankTy.getElementType());
|
||||
auto shape = rewriter.create<tensor::FromElementsOp>(loc, newDimSizes);
|
||||
return rewriter.create<stablehlo::DynamicReshapeOp>(loc, outTy, tensor, shape)
|
||||
.getResult();
|
||||
}
|
||||
|
||||
// TODO: support splitDim & outerLength to be Value
|
||||
FailureOr<Value> splitTensor(PatternRewriter &rewriter, Operation *op,
|
||||
Value tensor, int64_t splitDim,
|
||||
int64_t outerLength, size_t dimSizeIndexBits) {
|
||||
auto dimSizesInfo =
|
||||
getDimSizesOfTensor(rewriter, op, tensor, dimSizeIndexBits);
|
||||
|
||||
if (failed(dimSizesInfo))
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "failed to get dimension sizes of the input");
|
||||
|
||||
auto dimSizes = *dimSizesInfo;
|
||||
int64_t rank = dimSizes.size();
|
||||
splitDim = toPositiveDim(splitDim, rank);
|
||||
|
||||
auto loc = op->getLoc();
|
||||
auto rankTy = dyn_cast<RankedTensorType>(tensor.getType());
|
||||
auto oldShape = rankTy.getShape();
|
||||
Type intType = rewriter.getIntegerType(dimSizeIndexBits);
|
||||
|
||||
if (splitDim < 0 || splitDim >= rank) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "split dimensions must be within the rank of the tensor");
|
||||
}
|
||||
|
||||
int64_t newRank = rank + 1;
|
||||
auto outerLengthValue = rewriter.create<arith::ConstantOp>(
|
||||
loc, rewriter.getIntegerAttr(intType, outerLength));
|
||||
|
||||
auto innerLengthValue = rewriter.create<arith::DivSIOp>(
|
||||
loc, dimSizes[splitDim], outerLengthValue);
|
||||
|
||||
int64_t originShape = oldShape[splitDim];
|
||||
int64_t outerShape = outerLength;
|
||||
int64_t innerShape = originShape == ShapedType::kDynamic
|
||||
? ShapedType::kDynamic
|
||||
: originShape / outerLength;
|
||||
|
||||
std::vector<Value> newDimSizes;
|
||||
std::vector<int64_t> newShape;
|
||||
|
||||
newDimSizes.reserve(newRank);
|
||||
newShape.reserve(newRank);
|
||||
|
||||
for (int64_t k = 0; k < splitDim; ++k) {
|
||||
newDimSizes.push_back(dimSizes[k]);
|
||||
newShape.push_back(oldShape[k]);
|
||||
}
|
||||
newDimSizes.push_back(outerLengthValue);
|
||||
newShape.push_back(outerShape);
|
||||
newDimSizes.push_back(innerLengthValue);
|
||||
newShape.push_back(innerShape);
|
||||
|
||||
for (int64_t k = splitDim + 1; k < rank; ++k) {
|
||||
newDimSizes.push_back(dimSizes[k]);
|
||||
newShape.push_back(oldShape[k]);
|
||||
}
|
||||
|
||||
auto outTy = RankedTensorType::get(newShape, rankTy.getElementType());
|
||||
auto shape = rewriter.create<tensor::FromElementsOp>(loc, newDimSizes);
|
||||
return rewriter.create<stablehlo::DynamicReshapeOp>(loc, outTy, tensor, shape)
|
||||
.getResult();
|
||||
}
|
||||
|
||||
Value getConstantOfShape(PatternRewriter &rewriter, Location loc,
|
||||
const APFloat &constant, Value shape,
|
||||
TensorType outType) {
|
||||
|
|
|
@ -414,34 +414,44 @@ LogicalResult ConvertAtenOp<PrimsCollapseOp>::matchAndRewrite(
|
|||
return rewriter.notifyMatchFailure(
|
||||
op, "only constant end is currently supported");
|
||||
|
||||
start = toPositiveDim(start, rank);
|
||||
end = toPositiveDim(end, rank);
|
||||
SmallVector<int64_t, 4> dims;
|
||||
dims.reserve(rank);
|
||||
for (int r = 0; r < start; ++r)
|
||||
dims.push_back(r);
|
||||
int64_t collapsedDimSize = 1;
|
||||
for (int r = start; r <= end; ++r) {
|
||||
if (selfType.getShape()[r] == ShapedType::kDynamic)
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "the size of the dimension being collapsed is can't be unknown");
|
||||
collapsedDimSize *= selfType.getShape()[r];
|
||||
}
|
||||
dims.push_back(collapsedDimSize);
|
||||
for (int r = end + 1; r < rank; ++r)
|
||||
dims.push_back(r);
|
||||
auto collapseTensorInfo = hlo::collapseTensor(
|
||||
rewriter, op, adaptor.getA(), start, end, options.dimSizeIndexBits);
|
||||
if (failed(collapseTensorInfo))
|
||||
return rewriter.notifyMatchFailure(op, "failed to create collapsed tensor");
|
||||
|
||||
auto newDimSizesInfo = hlo::getDimSizesOfTensor(
|
||||
rewriter, op, adaptor.getA(), dims, options.dimSizeIndexBits);
|
||||
if (failed(newDimSizesInfo))
|
||||
rewriter.replaceOp(op, *collapseTensorInfo);
|
||||
return success();
|
||||
}
|
||||
|
||||
template <>
|
||||
LogicalResult ConvertAtenOp<PrimsSplitDimOp>::matchAndRewrite(
|
||||
PrimsSplitDimOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
auto selfType = adaptor.getA().getType().dyn_cast<TensorType>();
|
||||
if (!selfType) {
|
||||
return op.emitError("only tensor types are currently supported");
|
||||
}
|
||||
|
||||
auto rank = selfType.getRank();
|
||||
if (rank == 0)
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "failed to get dimension sizes of the input");
|
||||
auto newDimSizes = *newDimSizesInfo;
|
||||
auto stablehloShape =
|
||||
rewriter.create<tensor::FromElementsOp>(op.getLoc(), newDimSizes);
|
||||
rewriter.replaceOpWithNewOp<stablehlo::DynamicReshapeOp>(
|
||||
op, getTypeConverter()->convertType(op.getType()), adaptor.getA(),
|
||||
stablehloShape);
|
||||
op, "the rank of tensor must be greater than 0");
|
||||
|
||||
int64_t dim, outerLength;
|
||||
if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim)))
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "only constant dim is currently supported");
|
||||
if (!matchPattern(op.getOuterLength(), m_TorchConstantInt(&outerLength)))
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "only constant outerLength is currently supported");
|
||||
|
||||
auto splitTensorInfo = hlo::splitTensor(
|
||||
rewriter, op, adaptor.getA(), dim, outerLength, options.dimSizeIndexBits);
|
||||
|
||||
if (failed(splitTensorInfo))
|
||||
return rewriter.notifyMatchFailure(op, "failed to create split tensor");
|
||||
|
||||
rewriter.replaceOp(op, *splitTensorInfo);
|
||||
return success();
|
||||
}
|
||||
|
||||
|
@ -458,6 +468,7 @@ void mlir::torch::torch_to_stablehlo::populateViewLikeOpPatternsAndLegality(
|
|||
INSERT_ATENOP_PATTERN(AtenSqueezeDimOp);
|
||||
INSERT_ATENOP_PATTERN(AtenUnsqueezeOp);
|
||||
INSERT_ATENOP_PATTERN(PrimsCollapseOp);
|
||||
INSERT_ATENOP_PATTERN(PrimsSplitDimOp);
|
||||
#undef INSERT_ATENOP_PATTERN
|
||||
|
||||
#define INSERT_VIEW_OP_PATTERN(AtenOp) \
|
||||
|
|
|
@ -678,11 +678,6 @@ FX_IMPORTER_STABLEHLO_XFAIL_SET = {
|
|||
"NumToTensorIntModule_basic",
|
||||
"NumelModule_basic",
|
||||
"NumelZeroRankModule_basic",
|
||||
"PixelShuffleModuleFullDynamic_basic",
|
||||
"PixelShuffleModuleSpatiallyDynamic_basic",
|
||||
"PixelShuffleModuleSpatiallyStatic_basic",
|
||||
"PixelShuffleModuleStaticRank3Int64_basic",
|
||||
"PixelShuffleModuleStaticRank4Float32_basic",
|
||||
"PowIntFloatModule_basic",
|
||||
"PrimMaxIntModule_basic",
|
||||
"PrimMinIntDynamicModule_basic",
|
||||
|
@ -1157,6 +1152,8 @@ STABLEHLO_PASS_SET = {
|
|||
"Permute0RankModule_basic",
|
||||
"PermuteModule_basic",
|
||||
"PermuteNegativeIndexModule_basic",
|
||||
"PixelShuffleModuleStaticRank3Int64_basic",
|
||||
"PixelShuffleModuleStaticRank4Float32_basic",
|
||||
"PowIntFloatModule_basic",
|
||||
"PrimListUnpackNumMismatchModule_basic",
|
||||
"PrimMaxIntModule_basic",
|
||||
|
@ -1240,6 +1237,7 @@ STABLEHLO_PASS_SET = {
|
|||
"SliceWholeTensorModule_basic",
|
||||
"SortIntListReverse_basic",
|
||||
"SortIntList_basic",
|
||||
"SplitDimStaticModule_basic",
|
||||
"SplitTensorGetItem_Module_basic",
|
||||
"SplitTensorLastSmallerModule_basic",
|
||||
"SplitTensorListUnpackModule_basic",
|
||||
|
|
Loading…
Reference in New Issue