2022-07-22 11:32:45 +08:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
//
|
|
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
|
|
// Also available under a BSD-style license. See LICENSE.
|
|
|
|
//
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
2023-02-02 21:29:47 +08:00
|
|
|
#include "torch-mlir/Conversion/TorchToStablehlo/TorchToStablehlo.h"
|
2022-07-22 11:32:45 +08:00
|
|
|
|
|
|
|
#include "../PassDetail.h"
|
2023-02-02 21:29:47 +08:00
|
|
|
#include "PopulatePatterns.h"
|
|
|
|
|
2022-10-05 21:28:06 +08:00
|
|
|
#include "mlir/Dialect/Arith/IR/Arith.h"
|
2024-04-09 14:54:57 +08:00
|
|
|
#include "mlir/Dialect/Shape/IR/Shape.h"
|
2022-07-22 11:32:45 +08:00
|
|
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
2023-02-02 21:29:47 +08:00
|
|
|
#include "stablehlo/dialect/StablehloOps.h"
|
2024-01-30 01:59:33 +08:00
|
|
|
#include "torch-mlir/Conversion/TorchToStablehlo/StablehloLegalizeUtils.h"
|
2022-07-22 11:32:45 +08:00
|
|
|
#include "torch-mlir/Conversion/Utils/Utils.h"
|
|
|
|
#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
|
|
|
|
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
|
2022-07-22 15:18:18 +08:00
|
|
|
#include "torch-mlir/Dialect/Torch/Utils/TorchUpstream.h"
|
|
|
|
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
|
|
|
|
#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionDialect.h"
|
2022-07-22 11:32:45 +08:00
|
|
|
#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h"
|
2022-07-22 15:18:18 +08:00
|
|
|
#include <numeric>
|
2022-07-22 11:32:45 +08:00
|
|
|
|
|
|
|
using namespace mlir;
|
|
|
|
using namespace mlir::torch;
|
|
|
|
using namespace mlir::torch::Torch;
|
2022-07-22 15:18:18 +08:00
|
|
|
using namespace mlir::torch::TorchConversion;
|
2023-02-02 21:29:47 +08:00
|
|
|
using namespace mlir::torch::torch_to_stablehlo;
|
2022-07-22 11:32:45 +08:00
|
|
|
|
|
|
|
namespace {
|
|
|
|
// A dimension index from torch.dialect might outside the range [0, dimSize].
|
|
|
|
// The function is used to normalize the input index into the range.
|
2022-07-22 15:18:18 +08:00
|
|
|
Value getNormalizedDimSizeInternal(PatternRewriter &rewriter, Operation *op,
|
|
|
|
Value index, Value dimSize) {
|
2022-07-22 11:32:45 +08:00
|
|
|
auto loc = op->getLoc();
|
|
|
|
Value zero = rewriter.create<arith::ConstantOp>(
|
|
|
|
loc, rewriter.getIntegerAttr(rewriter.getI64Type(), 0));
|
|
|
|
|
|
|
|
// To normalize index into range [-dimSize, dimSize]
|
|
|
|
// index = min(max(-dimSize, index), dimSize)
|
|
|
|
auto negDimSize = rewriter.create<arith::SubIOp>(loc, zero, dimSize);
|
|
|
|
index = rewriter.create<arith::MaxSIOp>(loc, negDimSize, index);
|
|
|
|
index = rewriter.create<arith::MinSIOp>(loc, dimSize, index);
|
|
|
|
|
|
|
|
auto dimSizePlusIndex = rewriter.create<arith::AddIOp>(loc, dimSize, index);
|
|
|
|
auto indexPositive = rewriter.create<arith::CmpIOp>(
|
|
|
|
loc, arith::CmpIPredicate::sge, index, zero);
|
|
|
|
// get positive index: (index >=0) ? index: index + dimSize
|
2022-07-22 15:18:18 +08:00
|
|
|
return rewriter.create<arith::SelectOp>(loc, indexPositive, index,
|
|
|
|
dimSizePlusIndex);
|
2022-07-22 11:32:45 +08:00
|
|
|
}
|
|
|
|
|
2022-07-22 15:18:18 +08:00
|
|
|
// Emits a stablehlo.real_dynamic_slice that slices `input` along dimension
// `dimIndex` only; every other dimension is taken in full (start 0, end equal
// to its size from `dimSizes`, stride 1).
//
// `startIndex` and `endIndex` are expected to be already normalized into
// [0, dimSize]. An `endIndex` equal to zero is replaced by the size of the
// sliced dimension. All index values use the integer type of width
// `dimSizeIndexBits`. Returns the result of the slice op, typed `outTy`.
Value getDynamicSliceInternal(PatternRewriter &rewriter, Operation *op,
                              Type outTy, Value input, Value startIndex,
                              Value endIndex, Value step, size_t dimIndex,
                              ArrayRef<Value> dimSizes,
                              size_t dimSizeIndexBits) {
  Location loc = op->getLoc();
  Type indexIntTy = rewriter.getIntegerType(dimSizeIndexBits);
  Value constZero = rewriter.create<arith::ConstantOp>(
      loc, rewriter.getIntegerAttr(indexIntTy, 0));
  Value constOne = rewriter.create<arith::ConstantOp>(
      loc, rewriter.getIntegerAttr(indexIntTy, 1));

  auto inputTy = input.getType().dyn_cast<RankedTensorType>();
  const size_t rank = inputTy.getRank();

  SmallVector<Value, 4> sliceStarts;
  SmallVector<Value, 4> sliceEnds;
  SmallVector<Value, 4> sliceStrides;
  sliceStarts.reserve(rank);
  sliceEnds.reserve(rank);
  sliceStrides.reserve(rank);

  // An end index of zero is treated as "up to the end of the dimension".
  Value endIsZero = rewriter.create<arith::CmpIOp>(
      loc, arith::CmpIPredicate::eq, endIndex, constZero);
  endIndex = rewriter.create<arith::SelectOp>(loc, endIsZero,
                                              dimSizes[dimIndex], endIndex);

  // Only the sliced axis gets the caller-provided bounds and stride.
  for (size_t axis = 0; axis != rank; ++axis) {
    const bool isSlicedAxis = (axis == dimIndex);
    sliceStarts.push_back(isSlicedAxis ? startIndex : constZero);
    sliceEnds.push_back(isSlicedAxis ? endIndex : dimSizes[axis]);
    sliceStrides.push_back(isSlicedAxis ? step : constOne);
  }

  Value startTensor =
      rewriter.create<tensor::FromElementsOp>(loc, sliceStarts).getResult();
  Value endTensor =
      rewriter.create<tensor::FromElementsOp>(loc, sliceEnds).getResult();
  Value stridesTensor =
      rewriter.create<tensor::FromElementsOp>(loc, sliceStrides).getResult();

  return rewriter.create<stablehlo::RealDynamicSliceOp>(
      loc, outTy, input, startTensor, endTensor, stridesTensor);
}
|
|
|
|
|
2022-07-22 15:18:18 +08:00
|
|
|
// Get a dynamic slice of the tensor from startIndex to endIndex with stride
|
|
|
|
// step on the specifed dimension. The input startIndex(default to 0),
|
2022-07-22 11:32:45 +08:00
|
|
|
// endIndex(default to dimSize), and step(default to 1) can be optional.
|
2022-07-25 23:28:48 +08:00
|
|
|
FailureOr<Value> getDynamicSlice(PatternRewriter &rewriter, Operation *op,
|
2022-08-19 10:14:57 +08:00
|
|
|
Type outTy, Value input,
|
2022-12-20 18:17:27 +08:00
|
|
|
std::optional<Value> startIndexOpt,
|
|
|
|
std::optional<Value> endIndexOpt,
|
|
|
|
std::optional<Value> stepOpt, int64_t dim,
|
2022-09-01 10:36:02 +08:00
|
|
|
size_t dimSizeIndexBits) {
|
2022-07-22 11:32:45 +08:00
|
|
|
auto loc = op->getLoc();
|
|
|
|
auto inputTy = input.getType().dyn_cast<RankedTensorType>();
|
|
|
|
auto rank = inputTy.getRank();
|
|
|
|
|
|
|
|
dim = (dim + rank) % rank;
|
|
|
|
Value dimSize = rewriter.create<arith::IndexCastOp>(
|
2022-07-22 15:18:18 +08:00
|
|
|
loc, rewriter.getI64Type(),
|
2022-07-22 11:32:45 +08:00
|
|
|
rewriter.create<tensor::DimOp>(loc, input, dim));
|
|
|
|
|
2022-07-22 15:18:18 +08:00
|
|
|
Value normStartIndex =
|
|
|
|
startIndexOpt
|
|
|
|
? getNormalizedDimSizeInternal(rewriter, op, *startIndexOpt, dimSize)
|
|
|
|
: rewriter.create<arith::ConstantOp>(
|
|
|
|
loc, rewriter.getIntegerAttr(rewriter.getI64Type(), 0));
|
|
|
|
Value normEndIndex =
|
|
|
|
endIndexOpt
|
|
|
|
? getNormalizedDimSizeInternal(rewriter, op, *endIndexOpt, dimSize)
|
|
|
|
: dimSize;
|
|
|
|
Value step =
|
|
|
|
stepOpt ? *stepOpt
|
|
|
|
: rewriter.create<arith::ConstantOp>(
|
|
|
|
loc, rewriter.getIntegerAttr(rewriter.getI64Type(), 1));
|
2022-07-22 11:32:45 +08:00
|
|
|
|
2022-09-01 10:36:02 +08:00
|
|
|
if (dimSizeIndexBits == 32) {
|
|
|
|
Type intType = rewriter.getIntegerType(dimSizeIndexBits);
|
|
|
|
normStartIndex =
|
|
|
|
rewriter.create<arith::TruncIOp>(loc, intType, normStartIndex);
|
|
|
|
normEndIndex = rewriter.create<arith::TruncIOp>(loc, intType, normEndIndex);
|
|
|
|
step = rewriter.create<arith::TruncIOp>(loc, intType, step);
|
|
|
|
}
|
2022-07-25 23:28:48 +08:00
|
|
|
FailureOr<SmallVector<Value, 4>> dimSizesInfo =
|
2023-02-02 21:29:47 +08:00
|
|
|
hlo::getDimSizesOfTensor(rewriter, op, input, dimSizeIndexBits);
|
2022-07-25 23:28:48 +08:00
|
|
|
if (failed(dimSizesInfo))
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
op, "failed to get dimension sizes of the input");
|
2022-07-22 11:32:45 +08:00
|
|
|
|
2022-07-25 23:28:48 +08:00
|
|
|
auto dimSizes = *dimSizesInfo;
|
2022-08-19 10:14:57 +08:00
|
|
|
return getDynamicSliceInternal(rewriter, op, outTy, input, normStartIndex,
|
2022-09-01 10:36:02 +08:00
|
|
|
normEndIndex, step, dim, dimSizes,
|
|
|
|
dimSizeIndexBits);
|
2022-07-22 11:32:45 +08:00
|
|
|
}
|
2022-07-22 15:18:18 +08:00
|
|
|
|
|
|
|
// This defines a template to construct ops whose legalizations are
|
|
|
|
// specialized.
|
|
|
|
template <typename AtenOpT>
|
2022-09-01 10:36:02 +08:00
|
|
|
class ConvertAtenViewOp : public ConvertAtenOp<AtenOpT> {
|
2022-08-19 10:14:57 +08:00
|
|
|
public:
|
2022-09-01 10:36:02 +08:00
|
|
|
using ConvertAtenOp<AtenOpT>::ConvertAtenOp;
|
2022-07-22 15:18:18 +08:00
|
|
|
using OpAdaptor = typename AtenOpT::Adaptor;
|
|
|
|
|
2022-08-19 10:14:57 +08:00
|
|
|
LogicalResult
|
|
|
|
matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
|
|
|
|
ConversionPatternRewriter &rewriter) const override {
|
2022-07-22 15:18:18 +08:00
|
|
|
auto rankType =
|
2022-12-08 04:20:41 +08:00
|
|
|
adaptor.getSelf().getType().template dyn_cast<RankedTensorType>();
|
2022-07-22 15:18:18 +08:00
|
|
|
if (!rankType)
|
|
|
|
return op.emitError("Only ranked tensor types are currently supported");
|
|
|
|
|
|
|
|
SmallVector<Value, 4> dimSizes;
|
|
|
|
if (!getAtenViewOpSizes(op, adaptor, rewriter, dimSizes)) {
|
|
|
|
return op.emitError("Dims size must be a list of Scalar");
|
|
|
|
}
|
|
|
|
|
|
|
|
auto loc = op.getLoc();
|
2024-04-09 14:54:57 +08:00
|
|
|
if (dimSizes.size() == 0 || rankType.getRank() == 0) {
|
2023-02-02 21:29:47 +08:00
|
|
|
rewriter.replaceOpWithNewOp<stablehlo::ReshapeOp>(
|
2022-07-22 15:18:18 +08:00
|
|
|
op,
|
|
|
|
OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
|
|
|
|
op.getType()),
|
2022-12-08 04:20:41 +08:00
|
|
|
adaptor.getSelf());
|
2022-07-22 15:18:18 +08:00
|
|
|
return success();
|
|
|
|
}
|
|
|
|
|
2022-08-19 10:14:57 +08:00
|
|
|
std::for_each(dimSizes.begin(), dimSizes.end(), [&](Value &dSize) {
|
2022-07-22 15:18:18 +08:00
|
|
|
dSize = rewriter.create<ToI64Op>(loc, dSize).getResult();
|
|
|
|
return dSize;
|
|
|
|
});
|
|
|
|
|
2024-04-09 14:54:57 +08:00
|
|
|
Value numel = rewriter.create<shape::NumElementsOp>(
|
|
|
|
loc, rewriter.create<shape::ShapeOfOp>(loc, adaptor.getSelf()));
|
2022-07-22 15:18:18 +08:00
|
|
|
|
2023-02-02 21:29:47 +08:00
|
|
|
Value stablehloShape =
|
|
|
|
rewriter.create<tensor::FromElementsOp>(loc, dimSizes);
|
|
|
|
Value computedShape = rewriter.create<stablehlo::ComputeReshapeShapeOp>(
|
|
|
|
loc, stablehloShape.getType(), numel, stablehloShape);
|
|
|
|
rewriter.replaceOpWithNewOp<stablehlo::DynamicReshapeOp>(
|
2022-07-22 15:18:18 +08:00
|
|
|
op,
|
|
|
|
OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
|
|
|
|
op.getType()),
|
2022-12-08 04:20:41 +08:00
|
|
|
adaptor.getSelf(), computedShape);
|
2022-07-22 15:18:18 +08:00
|
|
|
return success();
|
|
|
|
}
|
|
|
|
|
2022-08-19 10:14:57 +08:00
|
|
|
bool getAtenViewOpSizes(AtenOpT op, OpAdaptor adaptor,
|
|
|
|
ConversionPatternRewriter &rewriter,
|
|
|
|
SmallVector<Value, 4> &dimSizes) const;
|
2022-07-22 15:18:18 +08:00
|
|
|
};
|
|
|
|
|
|
|
|
template <>
|
|
|
|
bool ConvertAtenViewOp<AtenViewOp>::getAtenViewOpSizes(
|
2022-08-19 10:14:57 +08:00
|
|
|
AtenViewOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter,
|
|
|
|
SmallVector<Value, 4> &dimSizes) const {
|
2022-12-08 04:20:41 +08:00
|
|
|
return getListConstructElements(adaptor.getSize(), dimSizes);
|
2022-07-22 15:18:18 +08:00
|
|
|
}
|
|
|
|
|
|
|
|
template <>
|
|
|
|
bool ConvertAtenViewOp<AtenReshapeOp>::getAtenViewOpSizes(
|
2022-08-19 10:14:57 +08:00
|
|
|
AtenReshapeOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter,
|
|
|
|
SmallVector<Value, 4> &dimSizes) const {
|
2022-12-08 04:20:41 +08:00
|
|
|
return getListConstructElements(adaptor.getShape(), dimSizes);
|
2022-07-22 15:18:18 +08:00
|
|
|
}
|
2022-09-01 10:36:02 +08:00
|
|
|
} // namespace
|
|
|
|
|
|
|
|
template <>
|
|
|
|
LogicalResult ConvertAtenOp<AtenSliceTensorOp>::matchAndRewrite(
|
|
|
|
AtenSliceTensorOp op, OpAdaptor adaptor,
|
|
|
|
ConversionPatternRewriter &rewriter) const {
|
2022-12-08 04:20:41 +08:00
|
|
|
auto self = adaptor.getSelf();
|
|
|
|
auto selfTy = self.getType().cast<RankedTensorType>();
|
2022-09-01 10:36:02 +08:00
|
|
|
if (!selfTy)
|
|
|
|
return op.emitError("only ranked tensor types are supported");
|
|
|
|
auto outTy =
|
|
|
|
getTypeConverter()->convertType(op.getType()).cast<RankedTensorType>();
|
|
|
|
int64_t dim;
|
2022-12-08 04:20:41 +08:00
|
|
|
if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim)))
|
2022-09-01 10:36:02 +08:00
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
op, "only constant dim is currently supported");
|
2023-04-07 19:49:35 +08:00
|
|
|
int64_t inputRank = selfTy.getRank();
|
|
|
|
dim = toPositiveDim(dim, inputRank);
|
|
|
|
if (!isValidDim(dim, inputRank))
|
|
|
|
return rewriter.notifyMatchFailure(op, "dim is statically invalid");
|
2022-09-01 10:36:02 +08:00
|
|
|
|
2022-12-20 18:17:27 +08:00
|
|
|
auto getOptionalVal = [&](Value val) -> std::optional<Value> {
|
2022-09-01 10:36:02 +08:00
|
|
|
if (val.getType().isa<Torch::NoneType>()) {
|
2022-12-14 18:44:05 +08:00
|
|
|
return std::nullopt;
|
2022-09-01 10:36:02 +08:00
|
|
|
} else {
|
|
|
|
return val;
|
|
|
|
}
|
|
|
|
};
|
|
|
|
|
2022-12-20 18:17:27 +08:00
|
|
|
std::optional<Value> start = getOptionalVal(adaptor.getStart());
|
|
|
|
std::optional<Value> end = getOptionalVal(adaptor.getEnd());
|
|
|
|
std::optional<Value> step = getOptionalVal(adaptor.getStep());
|
2022-09-01 10:36:02 +08:00
|
|
|
|
|
|
|
FailureOr<Value> sliceInfo =
|
|
|
|
getDynamicSlice(rewriter, op, outTy, self, start, end, step, dim,
|
|
|
|
options.dimSizeIndexBits);
|
|
|
|
if (failed(sliceInfo))
|
|
|
|
return op.emitError("can not create a dynmaic slice");
|
|
|
|
|
|
|
|
auto slice = *sliceInfo;
|
|
|
|
rewriter.replaceOp(op, slice);
|
|
|
|
return success();
|
|
|
|
}
|
2022-07-22 15:18:18 +08:00
|
|
|
|
2022-07-25 23:28:48 +08:00
|
|
|
template <>
|
|
|
|
LogicalResult ConvertAtenOp<AtenSqueezeOp>::matchAndRewrite(
|
|
|
|
AtenSqueezeOp op, OpAdaptor adaptor,
|
|
|
|
ConversionPatternRewriter &rewriter) const {
|
2022-12-08 04:20:41 +08:00
|
|
|
auto self = adaptor.getSelf();
|
|
|
|
auto selfTy = self.getType().cast<RankedTensorType>();
|
2022-07-25 23:28:48 +08:00
|
|
|
if (!selfTy)
|
|
|
|
return op.emitError("only ranked tensor types are supported");
|
|
|
|
|
|
|
|
auto rank = selfTy.getRank();
|
|
|
|
if (rank == 0)
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
op, "The rank of tensor must be greater than 0");
|
|
|
|
|
|
|
|
SmallVector<int64_t, 4> dims;
|
|
|
|
dims.reserve(rank);
|
|
|
|
for (int r = 0; r < rank; ++r) {
|
|
|
|
auto dSize = selfTy.getShape()[r];
|
2022-12-02 12:38:28 +08:00
|
|
|
if (dSize == ShapedType::kDynamic)
|
2022-07-25 23:28:48 +08:00
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
op, "the size of the dimension being squeezed can't be unknown");
|
|
|
|
if (dSize != 1)
|
|
|
|
dims.push_back(r);
|
|
|
|
}
|
2022-08-23 16:47:21 +08:00
|
|
|
if (dims.size() == 0) {
|
2023-02-02 21:29:47 +08:00
|
|
|
rewriter.replaceOpWithNewOp<stablehlo::ReshapeOp>(
|
2022-08-23 16:47:21 +08:00
|
|
|
op, getTypeConverter()->convertType(op.getType()), self);
|
|
|
|
return success();
|
|
|
|
}
|
2022-07-25 23:28:48 +08:00
|
|
|
|
2023-02-02 21:29:47 +08:00
|
|
|
auto newDimSizesInfo = hlo::getDimSizesOfTensor(rewriter, op, self, dims,
|
|
|
|
options.dimSizeIndexBits);
|
2022-07-25 23:28:48 +08:00
|
|
|
if (failed(newDimSizesInfo))
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
op, "failed to get dimension sizes of the input");
|
|
|
|
auto newDimSizes = *newDimSizesInfo;
|
2023-02-02 21:29:47 +08:00
|
|
|
auto stablehloShape =
|
2022-07-25 23:28:48 +08:00
|
|
|
rewriter.create<tensor::FromElementsOp>(op.getLoc(), newDimSizes);
|
2023-02-02 21:29:47 +08:00
|
|
|
rewriter.replaceOpWithNewOp<stablehlo::DynamicReshapeOp>(
|
|
|
|
op, getTypeConverter()->convertType(op.getType()), self, stablehloShape);
|
2022-07-25 23:28:48 +08:00
|
|
|
return success();
|
|
|
|
}
|
|
|
|
|
|
|
|
template <>
|
|
|
|
LogicalResult ConvertAtenOp<AtenSqueezeDimOp>::matchAndRewrite(
|
|
|
|
AtenSqueezeDimOp op, OpAdaptor adaptor,
|
|
|
|
ConversionPatternRewriter &rewriter) const {
|
2022-12-08 04:20:41 +08:00
|
|
|
auto self = adaptor.getSelf();
|
|
|
|
auto selfTy = self.getType().cast<RankedTensorType>();
|
2022-07-25 23:28:48 +08:00
|
|
|
if (!selfTy)
|
|
|
|
return op.emitError("only ranked tensor types are supported");
|
|
|
|
|
|
|
|
auto rank = selfTy.getRank();
|
|
|
|
if (rank == 0)
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
op, "the rank of tensor must be greater than 0");
|
|
|
|
|
2023-04-07 19:49:35 +08:00
|
|
|
int64_t dim;
|
|
|
|
if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim)))
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
op, "only constant dim is currently supported");
|
2022-07-25 23:28:48 +08:00
|
|
|
dim = toPositiveDim(dim, rank);
|
2023-04-07 19:49:35 +08:00
|
|
|
if (!isValidDim(dim, rank))
|
|
|
|
return rewriter.notifyMatchFailure(op, "dim is statically invalid");
|
|
|
|
|
2022-07-25 23:28:48 +08:00
|
|
|
if (selfTy.getShape()[dim] != 1) {
|
2022-12-02 12:38:28 +08:00
|
|
|
if (selfTy.getShape()[dim] == ShapedType::kDynamic)
|
2022-07-25 23:28:48 +08:00
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
op, "the size of the dimension being squeezed is can't be unknown");
|
|
|
|
|
2022-12-08 04:20:41 +08:00
|
|
|
rewriter.replaceOp(op, adaptor.getSelf());
|
2022-07-25 23:28:48 +08:00
|
|
|
return success();
|
|
|
|
}
|
|
|
|
|
|
|
|
SmallVector<int64_t, 4> dims(rank);
|
|
|
|
std::iota(dims.begin(), dims.end(), 0);
|
|
|
|
dims.erase(dims.begin() + dim);
|
2022-08-23 16:47:21 +08:00
|
|
|
if (dims.size() == 0) {
|
2023-02-02 21:29:47 +08:00
|
|
|
rewriter.replaceOpWithNewOp<stablehlo::ReshapeOp>(
|
2022-08-23 16:47:21 +08:00
|
|
|
op, getTypeConverter()->convertType(op.getType()), self);
|
|
|
|
return success();
|
|
|
|
}
|
2023-02-02 21:29:47 +08:00
|
|
|
auto newDimSizesInfo = hlo::getDimSizesOfTensor(rewriter, op, self, dims,
|
|
|
|
options.dimSizeIndexBits);
|
2022-07-25 23:28:48 +08:00
|
|
|
if (failed(newDimSizesInfo))
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
op, "failed to get dimension sizes of the input");
|
|
|
|
auto newDimSizes = *newDimSizesInfo;
|
2023-02-02 21:29:47 +08:00
|
|
|
auto stablehloShape =
|
2022-07-25 23:28:48 +08:00
|
|
|
rewriter.create<tensor::FromElementsOp>(op.getLoc(), newDimSizes);
|
2023-02-02 21:29:47 +08:00
|
|
|
rewriter.replaceOpWithNewOp<stablehlo::DynamicReshapeOp>(
|
|
|
|
op, getTypeConverter()->convertType(op.getType()), self, stablehloShape);
|
2022-07-25 23:28:48 +08:00
|
|
|
return success();
|
|
|
|
}
|
|
|
|
|
|
|
|
template <>
|
|
|
|
LogicalResult ConvertAtenOp<AtenUnsqueezeOp>::matchAndRewrite(
|
|
|
|
AtenUnsqueezeOp op, OpAdaptor adaptor,
|
|
|
|
ConversionPatternRewriter &rewriter) const {
|
2022-12-08 04:20:41 +08:00
|
|
|
auto selfType = adaptor.getSelf().getType().dyn_cast<TensorType>();
|
2022-07-25 23:28:48 +08:00
|
|
|
if (!selfType) {
|
|
|
|
return op.emitError("only tensor types are currently supported");
|
|
|
|
}
|
|
|
|
|
|
|
|
int64_t dim;
|
2022-12-08 04:20:41 +08:00
|
|
|
if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim)))
|
2022-07-25 23:28:48 +08:00
|
|
|
return op->emitError("dim must be a Scalar constant");
|
2024-01-30 01:59:33 +08:00
|
|
|
int64_t inputRank =
|
|
|
|
adaptor.getSelf().getType().cast<RankedTensorType>().getRank();
|
2023-04-07 19:49:35 +08:00
|
|
|
dim = toPositiveDim(dim, inputRank + 1);
|
|
|
|
if (!isValidDim(dim, inputRank + 1))
|
|
|
|
return rewriter.notifyMatchFailure(op, "dim is statically invalid");
|
2022-07-25 23:28:48 +08:00
|
|
|
|
2023-02-02 21:29:47 +08:00
|
|
|
auto unsqzTensorInfo = hlo::unsqueezeTensor(rewriter, op, adaptor.getSelf(),
|
|
|
|
{dim}, options.dimSizeIndexBits);
|
2022-07-25 23:28:48 +08:00
|
|
|
if (failed(unsqzTensorInfo))
|
|
|
|
return rewriter.notifyMatchFailure(op,
|
|
|
|
"failed to create unsqueezed tensor");
|
|
|
|
|
|
|
|
rewriter.replaceOp(op, *unsqzTensorInfo);
|
|
|
|
return success();
|
|
|
|
}
|
2022-07-22 11:32:45 +08:00
|
|
|
|
[Torch] Add decomposition of RepeatInterleaveSelfInt Op (#3075)
Decomposition RepeatInterleaveSelfInt with following ops:
```python
def my_repeat_interleave(input, repeats, dim=None):
if dim is None:
# Flatten the input and then repeat
return input.flatten().unsqueeze(-1).tile((1, repeats)).flatten()
else:
# Calculate the shape after repeat
expanded_shape = list(input.shape)
expanded_shape[dim] *= repeats
# Repeat the tensor along the specified dimension
repeat_shape = [1] * (input.dim() + 1)
repeat_shape[dim + 1] = repeats
input = input.unsqueeze(-1)
# Tile and then reshape
tiled = torch.tile(input, repeat_shape)
# Rearrange and reshape
repeated = tiled.reshape(*expanded_shape)
return repeated
```
I passed the tests of stablehlo and linalg. When testing onnx, strange
things happened.
In torch-mlir's CI **torch_nightly** and my own
environment(torch==2.4.0.dev20240318+cpu), it can **pass the pass**.
In torch-mlir's CI **torch_stable**, it **failed**.
The test case is `RepeatInterleaveSelfIntNoDimModule_basic`, the result
shape should be [120].
```python
class RepeatInterleaveSelfIntNoDimModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([3, 4, 5], torch.float32, True),
])
def forward(self, x):
return x.repeat_interleave(2)
@register_test_case(module_factory=lambda: RepeatInterleaveSelfIntNoDimModule())
def RepeatInterleaveSelfIntNoDimModule_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 4, 5))
```
The error log is as follows:
```
Unexpected outcome summary: (onnx)
****** Failed tests - 1 tests
FAIL - "RepeatInterleaveSelfIntNoDimModule_basic"
@ trace item #0 - call to "forward"
@ output of call to "forward"
ERROR: shape (torch.Size([6, 4, 5])) is not equal to golden shape (torch.Size([120]))
```
@rsuderman
Would you please help me check what's wrong with my PR? Thanks a lot.
2024-04-18 06:27:51 +08:00
|
|
|
template <>
|
|
|
|
LogicalResult ConvertAtenOp<PrimsCollapseOp>::matchAndRewrite(
|
|
|
|
PrimsCollapseOp op, OpAdaptor adaptor,
|
|
|
|
ConversionPatternRewriter &rewriter) const {
|
|
|
|
auto selfType = adaptor.getA().getType().dyn_cast<TensorType>();
|
|
|
|
if (!selfType) {
|
|
|
|
return op.emitError("only tensor types are currently supported");
|
|
|
|
}
|
|
|
|
|
|
|
|
auto rank = selfType.getRank();
|
|
|
|
if (rank == 0)
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
op, "the rank of tensor must be greater than 0");
|
|
|
|
|
|
|
|
int64_t start, end;
|
|
|
|
if (!matchPattern(op.getStart(), m_TorchConstantInt(&start)))
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
op, "only constant start is currently supported");
|
|
|
|
if (!matchPattern(op.getEnd(), m_TorchConstantInt(&end)))
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
op, "only constant end is currently supported");
|
|
|
|
|
|
|
|
start = toPositiveDim(start, rank);
|
|
|
|
end = toPositiveDim(end, rank);
|
|
|
|
SmallVector<int64_t, 4> dims;
|
|
|
|
dims.reserve(rank);
|
|
|
|
for (int r = 0; r < start; ++r)
|
|
|
|
dims.push_back(r);
|
|
|
|
int64_t collapsedDimSize = 1;
|
|
|
|
for (int r = start; r <= end; ++r) {
|
|
|
|
if (selfType.getShape()[r] == ShapedType::kDynamic)
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
op, "the size of the dimension being collapsed is can't be unknown");
|
|
|
|
collapsedDimSize *= selfType.getShape()[r];
|
|
|
|
}
|
|
|
|
dims.push_back(collapsedDimSize);
|
|
|
|
for (int r = end + 1; r < rank; ++r)
|
|
|
|
dims.push_back(r);
|
|
|
|
|
|
|
|
auto newDimSizesInfo = hlo::getDimSizesOfTensor(
|
|
|
|
rewriter, op, adaptor.getA(), dims, options.dimSizeIndexBits);
|
|
|
|
if (failed(newDimSizesInfo))
|
|
|
|
return rewriter.notifyMatchFailure(
|
|
|
|
op, "failed to get dimension sizes of the input");
|
|
|
|
auto newDimSizes = *newDimSizesInfo;
|
|
|
|
auto stablehloShape =
|
|
|
|
rewriter.create<tensor::FromElementsOp>(op.getLoc(), newDimSizes);
|
|
|
|
rewriter.replaceOpWithNewOp<stablehlo::DynamicReshapeOp>(
|
|
|
|
op, getTypeConverter()->convertType(op.getType()), adaptor.getA(),
|
|
|
|
stablehloShape);
|
|
|
|
return success();
|
|
|
|
}
|
|
|
|
|
2023-02-02 21:29:47 +08:00
|
|
|
void mlir::torch::torch_to_stablehlo::populateViewLikeOpPatternsAndLegality(
|
2022-07-22 11:32:45 +08:00
|
|
|
TypeConverter &typeConverter, RewritePatternSet &patterns,
|
2023-02-02 21:29:47 +08:00
|
|
|
ConversionTarget &target, const TorchToStablehloOptions &options) {
|
2022-07-22 11:32:45 +08:00
|
|
|
MLIRContext *context = patterns.getContext();
|
|
|
|
|
|
|
|
#define INSERT_ATENOP_PATTERN(AtenOp) \
|
|
|
|
target.addIllegalOp<AtenOp>(); \
|
2022-09-01 10:36:02 +08:00
|
|
|
patterns.add<ConvertAtenOp<AtenOp>>(typeConverter, context, options)
|
2022-07-22 11:32:45 +08:00
|
|
|
INSERT_ATENOP_PATTERN(AtenSliceTensorOp);
|
2022-07-25 23:28:48 +08:00
|
|
|
INSERT_ATENOP_PATTERN(AtenSqueezeOp);
|
|
|
|
INSERT_ATENOP_PATTERN(AtenSqueezeDimOp);
|
|
|
|
INSERT_ATENOP_PATTERN(AtenUnsqueezeOp);
|
[Torch] Add decomposition of RepeatInterleaveSelfInt Op (#3075)
Decomposition RepeatInterleaveSelfInt with following ops:
```python
def my_repeat_interleave(input, repeats, dim=None):
if dim is None:
# Flatten the input and then repeat
return input.flatten().unsqueeze(-1).tile((1, repeats)).flatten()
else:
# Calculate the shape after repeat
expanded_shape = list(input.shape)
expanded_shape[dim] *= repeats
# Repeat the tensor along the specified dimension
repeat_shape = [1] * (input.dim() + 1)
repeat_shape[dim + 1] = repeats
input = input.unsqueeze(-1)
# Tile and then reshape
tiled = torch.tile(input, repeat_shape)
# Rearrange and reshape
repeated = tiled.reshape(*expanded_shape)
return repeated
```
I passed the tests of stablehlo and linalg. When testing onnx, strange
things happened.
In torch-mlir's CI **torch_nightly** and my own
environment(torch==2.4.0.dev20240318+cpu), it can **pass the pass**.
In torch-mlir's CI **torch_stable**, it **failed**.
The test case is `RepeatInterleaveSelfIntNoDimModule_basic`, the result
shape should be [120].
```python
class RepeatInterleaveSelfIntNoDimModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([3, 4, 5], torch.float32, True),
])
def forward(self, x):
return x.repeat_interleave(2)
@register_test_case(module_factory=lambda: RepeatInterleaveSelfIntNoDimModule())
def RepeatInterleaveSelfIntNoDimModule_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 4, 5))
```
The error log is as follows:
```
Unexpected outcome summary: (onnx)
****** Failed tests - 1 tests
FAIL - "RepeatInterleaveSelfIntNoDimModule_basic"
@ trace item #0 - call to "forward"
@ output of call to "forward"
ERROR: shape (torch.Size([6, 4, 5])) is not equal to golden shape (torch.Size([120]))
```
@rsuderman
Would you please help me check what's wrong with my PR? Thanks a lot.
2024-04-18 06:27:51 +08:00
|
|
|
INSERT_ATENOP_PATTERN(PrimsCollapseOp);
|
2022-07-22 11:32:45 +08:00
|
|
|
#undef INSERT_ATENOP_PATTERN
|
|
|
|
|
2022-08-19 10:14:57 +08:00
|
|
|
#define INSERT_VIEW_OP_PATTERN(AtenOp) \
|
|
|
|
target.addIllegalOp<AtenOp>(); \
|
2022-09-01 10:36:02 +08:00
|
|
|
patterns.add<ConvertAtenViewOp<AtenOp>>(typeConverter, context, options)
|
2022-08-19 10:14:57 +08:00
|
|
|
INSERT_VIEW_OP_PATTERN(AtenViewOp);
|
|
|
|
INSERT_VIEW_OP_PATTERN(AtenReshapeOp);
|
2022-07-22 15:18:18 +08:00
|
|
|
#undef INSERT_VIEW_OP_PATTERN
|
2022-07-22 11:32:45 +08:00
|
|
|
}
|