mirror of https://github.com/llvm/torch-mlir
[stablehlo] Support PrimsCollapseOp and PrimsSplitDimOp in stablehlo (#3230)
parent b1e2241479
commit 0a5ff68d9d
include/torch-mlir/Conversion/TorchToStablehlo/StablehloLegalizeUtils.h:

```diff
@@ -69,6 +69,17 @@ FailureOr<Value> unsqueezeTensor(PatternRewriter &rewriter, Operation *op,
                                  Value tensor, ArrayRef<int64_t> inputUnsqzDims,
                                  size_t dimSizeIndexBits);
 
+// Get a tensor that collapses the specified dimensions of the input tensor
+FailureOr<Value> collapseTensor(PatternRewriter &rewriter, Operation *op,
+                                Value tensor, int64_t collapseStartDim,
+                                int64_t collapseEndDim,
+                                size_t dimSizeIndexBits);
+
+// Get a tensor that splits the specified dimension of the input tensor
+FailureOr<Value> splitTensor(PatternRewriter &rewriter, Operation *op,
+                             Value tensor, int64_t splitDim,
+                             int64_t outerLength, size_t dimSizeIndexBits);
+
 Value getConstantOfShape(PatternRewriter &rewriter, Location loc,
                          const APFloat &constant, Value shape,
                          TensorType outType);
```
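The two new helpers are shape-level inverses of each other: `collapseTensor` folds the inclusive dimension range `[collapseStartDim, collapseEndDim]` into a single dimension, and `splitTensor` factors `splitDim` into an `(outerLength, size / outerLength)` pair. A minimal Python sketch of the static-shape arithmetic, using `None` where MLIR uses `ShapedType::kDynamic` (the function names and the `None` convention are illustrative, not part of the patch):

```python
def collapse_shape(shape, start, end):
    """Fold dims [start, end] (inclusive) into one dimension."""
    collapsed = 1
    for k in range(start, end + 1):
        # Any dynamic dim in the range makes the collapsed dim dynamic.
        if collapsed is None or shape[k] is None:
            collapsed = None
        else:
            collapsed *= shape[k]
    return shape[:start] + [collapsed] + shape[end + 1:]

def split_shape(shape, dim, outer_length):
    """Split `dim` into (outer_length, size // outer_length)."""
    size = shape[dim]
    inner = None if size is None else size // outer_length
    return shape[:dim] + [outer_length, inner] + shape[dim + 1:]

assert collapse_shape([2, 3, 4], 1, 2) == [2, 12]
assert collapse_shape([2, None, 4], 1, 2) == [2, None]
assert split_shape([2, 12], 1, 3) == [2, 3, 4]
```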
lib/Conversion/TorchToStablehlo/StablehloLegalizeUtils.cpp:

```diff
@@ -9,6 +9,7 @@
 #include "torch-mlir/Conversion/TorchToStablehlo/StablehloLegalizeUtils.h"
 
 #include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Shape/IR/Shape.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "stablehlo/dialect/StablehloOps.h"
 #include "torch-mlir/Conversion/TorchToStablehlo/TorchToStablehlo.h"
```
```diff
@@ -306,6 +307,136 @@ FailureOr<Value> unsqueezeTensor(PatternRewriter &rewriter, Operation *op,
       .getResult();
 }
 
+FailureOr<Value> collapseTensor(PatternRewriter &rewriter, Operation *op,
+                                Value tensor, int64_t collapseStartDim,
+                                int64_t collapseEndDim,
+                                size_t dimSizeIndexBits) {
+
+  auto dimSizesInfo =
+      getDimSizesOfTensor(rewriter, op, tensor, dimSizeIndexBits);
+
+  if (failed(dimSizesInfo))
+    return rewriter.notifyMatchFailure(
+        op, "failed to get dimension sizes of the input");
+
+  auto dimSizes = *dimSizesInfo;
+  int64_t rank = dimSizes.size();
+
+  collapseStartDim = toPositiveDim(collapseStartDim, rank);
+  collapseEndDim = toPositiveDim(collapseEndDim, rank);
+
+  int64_t newRank = rank - (collapseEndDim - collapseStartDim + 1);
+
+  auto loc = op->getLoc();
+  auto rankTy = dyn_cast<RankedTensorType>(tensor.getType());
+  auto oldShape = rankTy.getShape();
+  Type intType = rewriter.getIntegerType(dimSizeIndexBits);
+
+  std::vector<Value> newDimSizes;
+  std::vector<int64_t> newShape;
+  newDimSizes.reserve(newRank);
+  newShape.reserve(newRank);
+
+  Value collapseDimSize = rewriter.create<arith::ConstantOp>(
+      loc, rewriter.getIntegerAttr(intType, 1));
+  int64_t collapseShape = 1;
+
+  for (int64_t k = collapseStartDim; k <= collapseEndDim; ++k) {
+    if (k < 0 || k >= rank) {
+      return rewriter.notifyMatchFailure(
+          op, "collapse dimensions must be within the rank of the tensor");
+    }
+    if (collapseShape == ShapedType::kDynamic ||
+        oldShape[k] == ShapedType::kDynamic) {
+      collapseShape = ShapedType::kDynamic;
+    } else {
+      collapseShape *= oldShape[k];
+    }
+    collapseDimSize =
+        rewriter.create<arith::MulIOp>(loc, collapseDimSize, dimSizes[k]);
+  }
+
+  for (int64_t k = 0; k < collapseStartDim; ++k) {
+    newDimSizes.push_back(dimSizes[k]);
+    newShape.push_back(oldShape[k]);
+  }
+  newDimSizes.push_back(collapseDimSize);
+  newShape.push_back(collapseShape);
+  for (int64_t k = collapseEndDim + 1; k < rank; ++k) {
+    newDimSizes.push_back(dimSizes[k]);
+    newShape.push_back(oldShape[k]);
+  }
+
+  auto outTy = RankedTensorType::get(newShape, rankTy.getElementType());
+  auto shape = rewriter.create<tensor::FromElementsOp>(loc, newDimSizes);
+  return rewriter.create<stablehlo::DynamicReshapeOp>(loc, outTy, tensor, shape)
+      .getResult();
+}
+
+// TODO: support splitDim & outerLength to be Value
+FailureOr<Value> splitTensor(PatternRewriter &rewriter, Operation *op,
+                             Value tensor, int64_t splitDim,
+                             int64_t outerLength, size_t dimSizeIndexBits) {
+  auto dimSizesInfo =
+      getDimSizesOfTensor(rewriter, op, tensor, dimSizeIndexBits);
+
+  if (failed(dimSizesInfo))
+    return rewriter.notifyMatchFailure(
+        op, "failed to get dimension sizes of the input");
+
+  auto dimSizes = *dimSizesInfo;
+  int64_t rank = dimSizes.size();
+  splitDim = toPositiveDim(splitDim, rank);
+
+  auto loc = op->getLoc();
+  auto rankTy = dyn_cast<RankedTensorType>(tensor.getType());
+  auto oldShape = rankTy.getShape();
+  Type intType = rewriter.getIntegerType(dimSizeIndexBits);
+
+  if (splitDim < 0 || splitDim >= rank) {
+    return rewriter.notifyMatchFailure(
+        op, "split dimensions must be within the rank of the tensor");
+  }
+
+  int64_t newRank = rank + 1;
+  auto outerLengthValue = rewriter.create<arith::ConstantOp>(
+      loc, rewriter.getIntegerAttr(intType, outerLength));
+
+  auto innerLengthValue = rewriter.create<arith::DivSIOp>(
+      loc, dimSizes[splitDim], outerLengthValue);
+
+  int64_t originShape = oldShape[splitDim];
+  int64_t outerShape = outerLength;
+  int64_t innerShape = originShape == ShapedType::kDynamic
+                           ? ShapedType::kDynamic
+                           : originShape / outerLength;
+
+  std::vector<Value> newDimSizes;
+  std::vector<int64_t> newShape;
+
+  newDimSizes.reserve(newRank);
+  newShape.reserve(newRank);
+
+  for (int64_t k = 0; k < splitDim; ++k) {
+    newDimSizes.push_back(dimSizes[k]);
+    newShape.push_back(oldShape[k]);
+  }
+  newDimSizes.push_back(outerLengthValue);
+  newShape.push_back(outerShape);
+  newDimSizes.push_back(innerLengthValue);
+  newShape.push_back(innerShape);
+
+  for (int64_t k = splitDim + 1; k < rank; ++k) {
+    newDimSizes.push_back(dimSizes[k]);
+    newShape.push_back(oldShape[k]);
+  }
+
+  auto outTy = RankedTensorType::get(newShape, rankTy.getElementType());
+  auto shape = rewriter.create<tensor::FromElementsOp>(loc, newDimSizes);
+  return rewriter.create<stablehlo::DynamicReshapeOp>(loc, outTy, tensor, shape)
+      .getResult();
+}
+
 Value getConstantOfShape(PatternRewriter &rewriter, Location loc,
                          const APFloat &constant, Value shape,
                          TensorType outType) {
```
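Both helpers keep two parallel shape representations: `newShape`, the static `int64_t` shape used to build the result `RankedTensorType` (where `kDynamic` poisons any product it touches), and `newDimSizes`, the runtime size `Value`s packed into a `tensor.from_elements` and fed to `stablehlo.dynamic_reshape`. The runtime side is a plain product or quotient; this sketch mirrors it with ints standing in for the `arith` op results (names are illustrative):

```python
def collapse_runtime_sizes(dim_sizes, start, end):
    # In the C++ this is a chain of arith.muli ops seeded with constant 1.
    collapsed = 1
    for k in range(start, end + 1):
        collapsed *= dim_sizes[k]
    return dim_sizes[:start] + [collapsed] + dim_sizes[end + 1:]

def split_runtime_sizes(dim_sizes, dim, outer_length):
    # In the C++ this is arith.divsi, which truncates: the lowering assumes
    # outer_length evenly divides the runtime size and does not verify it.
    inner = dim_sizes[dim] // outer_length
    return dim_sizes[:dim] + [outer_length, inner] + dim_sizes[dim + 1:]

assert collapse_runtime_sizes([2, 3, 4], 0, 1) == [6, 4]
assert split_runtime_sizes([6, 4], 0, 2) == [2, 3, 4]
```

Because the collapsed size is computed at runtime, dynamic dimensions no longer block the lowering; the old `PrimsCollapseOp` pattern removed in the next hunk bailed out as soon as any dimension in the collapse range was unknown.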
lib/Conversion/TorchToStablehlo/ViewLike.cpp:

```diff
@@ -414,34 +414,44 @@ LogicalResult ConvertAtenOp<PrimsCollapseOp>::matchAndRewrite(
     return rewriter.notifyMatchFailure(
         op, "only constant end is currently supported");
 
-  start = toPositiveDim(start, rank);
-  end = toPositiveDim(end, rank);
-  SmallVector<int64_t, 4> dims;
-  dims.reserve(rank);
-  for (int r = 0; r < start; ++r)
-    dims.push_back(r);
-  int64_t collapsedDimSize = 1;
-  for (int r = start; r <= end; ++r) {
-    if (selfType.getShape()[r] == ShapedType::kDynamic)
-      return rewriter.notifyMatchFailure(
-          op, "the size of the dimension being collapsed is can't be unknown");
-    collapsedDimSize *= selfType.getShape()[r];
-  }
-  dims.push_back(collapsedDimSize);
-  for (int r = end + 1; r < rank; ++r)
-    dims.push_back(r);
-
-  auto newDimSizesInfo = hlo::getDimSizesOfTensor(
-      rewriter, op, adaptor.getA(), dims, options.dimSizeIndexBits);
-  if (failed(newDimSizesInfo))
-    return rewriter.notifyMatchFailure(
-        op, "failed to get dimension sizes of the input");
-  auto newDimSizes = *newDimSizesInfo;
-  auto stablehloShape =
-      rewriter.create<tensor::FromElementsOp>(op.getLoc(), newDimSizes);
-  rewriter.replaceOpWithNewOp<stablehlo::DynamicReshapeOp>(
-      op, getTypeConverter()->convertType(op.getType()), adaptor.getA(),
-      stablehloShape);
+  auto collapseTensorInfo = hlo::collapseTensor(
+      rewriter, op, adaptor.getA(), start, end, options.dimSizeIndexBits);
+  if (failed(collapseTensorInfo))
+    return rewriter.notifyMatchFailure(op, "failed to create collapsed tensor");
+
+  rewriter.replaceOp(op, *collapseTensorInfo);
+  return success();
+}
+
+template <>
+LogicalResult ConvertAtenOp<PrimsSplitDimOp>::matchAndRewrite(
+    PrimsSplitDimOp op, OpAdaptor adaptor,
+    ConversionPatternRewriter &rewriter) const {
+  auto selfType = adaptor.getA().getType().dyn_cast<TensorType>();
+  if (!selfType) {
+    return op.emitError("only tensor types are currently supported");
+  }
+
+  auto rank = selfType.getRank();
+  if (rank == 0)
+    return rewriter.notifyMatchFailure(
+        op, "the rank of tensor must be greater than 0");
+
+  int64_t dim, outerLength;
+  if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim)))
+    return rewriter.notifyMatchFailure(
+        op, "only constant dim is currently supported");
+  if (!matchPattern(op.getOuterLength(), m_TorchConstantInt(&outerLength)))
+    return rewriter.notifyMatchFailure(
+        op, "only constant outerLength is currently supported");
+
+  auto splitTensorInfo = hlo::splitTensor(
+      rewriter, op, adaptor.getA(), dim, outerLength, options.dimSizeIndexBits);
+
+  if (failed(splitTensorInfo))
+    return rewriter.notifyMatchFailure(op, "failed to create split tensor");
+
+  rewriter.replaceOp(op, *splitTensorInfo);
   return success();
 }
```
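For reference, the ops these two patterns lower come from PyTorch's prims namespace. Their Python-level behavior looks like the following, assuming `torch.ops.prims.collapse` / `torch.ops.prims.split_dim` are exposed and that `collapse` treats `end` as inclusive, as this lowering does:

```python
import torch

a = torch.arange(24).reshape(2, 3, 4)

# Collapse dims 1..2 (inclusive) into one dimension of size 3 * 4.
collapsed = torch.ops.prims.collapse(a, 1, 2)
print(collapsed.shape)  # expected: torch.Size([2, 12])

# Split dim 1 into (outer_length=4, 12 // 4).
split = torch.ops.prims.split_dim(collapsed, 1, 4)
print(split.shape)      # expected: torch.Size([2, 4, 3])
```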
lib/Conversion/TorchToStablehlo/ViewLike.cpp (pattern registration):

```diff
@@ -458,6 +468,7 @@ void mlir::torch::torch_to_stablehlo::populateViewLikeOpPatternsAndLegality(
   INSERT_ATENOP_PATTERN(AtenSqueezeDimOp);
   INSERT_ATENOP_PATTERN(AtenUnsqueezeOp);
   INSERT_ATENOP_PATTERN(PrimsCollapseOp);
+  INSERT_ATENOP_PATTERN(PrimsSplitDimOp);
 #undef INSERT_ATENOP_PATTERN
 
 #define INSERT_VIEW_OP_PATTERN(AtenOp)                                         \
```
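With the pattern registered, the new lowering can be exercised from Python through the FX importer. A sketch, assuming `torch_mlir.fx.export_and_import` accepts `output_type="stablehlo"` (the `FX_IMPORTER_STABLEHLO_XFAIL_SET` changes below suggest this path exists, but the exact signature is an assumption):

```python
import torch
from torch_mlir import fx  # import path assumed

class SplitDim(torch.nn.Module):
    def forward(self, x):
        return torch.ops.prims.split_dim(x, 1, 2)

# Export through torch.export, import into the Torch dialect, then lower to
# StableHLO, which now exercises ConvertAtenOp<PrimsSplitDimOp>.
module = fx.export_and_import(SplitDim(), torch.randn(3, 4),
                              output_type="stablehlo")
print(module)
```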
projects/pt1/e2e_testing/xfail_sets.py:

```diff
@@ -678,11 +678,6 @@ FX_IMPORTER_STABLEHLO_XFAIL_SET = {
     "NumToTensorIntModule_basic",
     "NumelModule_basic",
     "NumelZeroRankModule_basic",
-    "PixelShuffleModuleFullDynamic_basic",
-    "PixelShuffleModuleSpatiallyDynamic_basic",
-    "PixelShuffleModuleSpatiallyStatic_basic",
-    "PixelShuffleModuleStaticRank3Int64_basic",
-    "PixelShuffleModuleStaticRank4Float32_basic",
     "PowIntFloatModule_basic",
     "PrimMaxIntModule_basic",
     "PrimMinIntDynamicModule_basic",
```
```diff
@@ -1157,6 +1152,8 @@ STABLEHLO_PASS_SET = {
     "Permute0RankModule_basic",
     "PermuteModule_basic",
     "PermuteNegativeIndexModule_basic",
+    "PixelShuffleModuleStaticRank3Int64_basic",
+    "PixelShuffleModuleStaticRank4Float32_basic",
     "PowIntFloatModule_basic",
     "PrimListUnpackNumMismatchModule_basic",
     "PrimMaxIntModule_basic",
```
```diff
@@ -1240,6 +1237,7 @@ STABLEHLO_PASS_SET = {
     "SliceWholeTensorModule_basic",
     "SortIntListReverse_basic",
     "SortIntList_basic",
+    "SplitDimStaticModule_basic",
     "SplitTensorGetItem_Module_basic",
     "SplitTensorLastSmallerModule_basic",
     "SplitTensorListUnpackModule_basic",
```
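`SplitDimStaticModule_basic` is new in `STABLEHLO_PASS_SET`, and the PixelShuffle modules leave the FX-importer XFAIL set because pixel shuffle decomposes through these prims ops. For orientation, an e2e test module of this kind follows torch-mlir's usual e2e conventions; this sketch is illustrative, not the actual test source from the PR:

```python
import torch
from torch_mlir_e2e_test.framework import TestUtils
from torch_mlir_e2e_test.registry import register_test_case
from torch_mlir_e2e_test.annotations import annotate_args, export

class SplitDimStaticModule(torch.nn.Module):
    @export
    @annotate_args([None, ([2, 12], torch.float32, True)])
    def forward(self, x):
        # Split dim 1 of a statically shaped tensor into (3, 4).
        return torch.ops.prims.split_dim(x, 1, 3)

@register_test_case(module_factory=lambda: SplitDimStaticModule())
def SplitDimStaticModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(2, 12))
```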