mirror of https://github.com/llvm/torch-mlir
[MHLO] Add [un]squeeze op patterns (#1099)
* [MHLO] Add [un]squeeze op patterns
* Conform to llvm coding standard
* minor update

Branch: pull/1106/head
Parent: f424930a28
Commit: f50d7013cd
@@ -36,23 +36,31 @@ static constexpr size_t kMhloDimSizeBits = 64;
 namespace {
 
-SmallVector<Value, 4> getDimSizesOfTensor(PatternRewriter &rewriter,
-                                          Operation *op, Value value) {
+SmallVector<size_t> toPositiveDims(ArrayRef<int64_t> dims, int64_t rank) {
+  SmallVector<size_t> posDims;
+  posDims.reserve(rank);
+  std::transform(
+      dims.begin(), dims.end(), std::back_inserter(posDims),
+      [rank](int64_t d) -> size_t { return toPositiveDim(d, rank); });
+  return posDims;
+}
+
+FailureOr<SmallVector<Value, 4>>
+getDimSizesOfTensor(PatternRewriter &rewriter, Operation *op, Value value,
+                    ArrayRef<int64_t> inpDims) {
   auto valueTy = value.getType().dyn_cast<RankedTensorType>();
   if (!valueTy) {
-    op->emitOpError("getDimSizesOfTensor(): the input is not a ranked tensor");
-    return {};
+    return rewriter.notifyMatchFailure(
+        op, "getDimSizesOfTensor(): the input is not a ranked tensor");
   }
 
   auto rank = valueTy.getRank();
-  if (rank == 0) {
-    return {};
-  }
 
+  auto dims = toPositiveDims(inpDims, rank);
   SmallVector<Value, 4> dimSizes;
-  dimSizes.reserve(rank);
+  dimSizes.reserve(dims.size());
 
   auto loc = op->getLoc();
-  for (auto d = 0; d < rank; ++d) {
+  for (auto d : dims) {
     dimSizes.emplace_back(rewriter.create<arith::IndexCastOp>(
         loc, rewriter.getIntegerType(kMhloDimSizeBits),
         rewriter.create<tensor::DimOp>(loc, value, d)));
@@ -60,6 +68,21 @@ SmallVector<Value, 4> getDimSizesOfTensor(PatternRewriter &rewriter,
   return dimSizes;
 }
 
+FailureOr<SmallVector<Value, 4>>
+getDimSizesOfTensor(PatternRewriter &rewriter, Operation *op, Value value) {
+  auto valueTy = value.getType().dyn_cast<RankedTensorType>();
+  if (!valueTy) {
+    return rewriter.notifyMatchFailure(
+        op, "getDimSizesOfTensor(): the input is not a ranked tensor");
+  }
+
+  auto rank = valueTy.getRank();
+  // Get int vector [0, 1, ..., rank-1]
+  std::vector<int64_t> dims(rank);
+  std::iota(dims.begin(), dims.end(), 0);
+  return getDimSizesOfTensor(rewriter, op, value, dims);
+}
+
 // A dimension index from the torch dialect might be outside the range
 // [0, dimSize]. This function normalizes the input index into that range.
 Value getNormalizedDimSizeInternal(PatternRewriter &rewriter, Operation *op,
@@ -140,10 +163,11 @@ Value getDynamicSliceInternal(PatternRewriter &rewriter, Operation *op,
 // Get a dynamic slice of the tensor from startIndex to endIndex with stride
 // step on the specified dimension. The inputs startIndex (defaults to 0),
 // endIndex (defaults to dimSize), and step (defaults to 1) are optional.
-Value getDynamicSlice(PatternRewriter &rewriter, Operation *op, Value input,
-                      llvm::Optional<Value> startIndexOpt,
-                      llvm::Optional<Value> endIndexOpt,
-                      llvm::Optional<Value> stepOpt, int64_t dim) {
+FailureOr<Value> getDynamicSlice(PatternRewriter &rewriter, Operation *op,
+                                 Value input,
+                                 llvm::Optional<Value> startIndexOpt,
+                                 llvm::Optional<Value> endIndexOpt,
+                                 llvm::Optional<Value> stepOpt, int64_t dim) {
   auto loc = op->getLoc();
   auto inputTy = input.getType().dyn_cast<RankedTensorType>();
   auto rank = inputTy.getRank();
@@ -174,8 +198,13 @@ Value getDynamicSlice(PatternRewriter &rewriter, Operation *op, Value input,
   normEndIndex = rewriter.create<arith::TruncIOp>(loc, i32Type, normEndIndex);
   step = rewriter.create<arith::TruncIOp>(loc, i32Type, step);
 #endif
-  auto dimSizes = getDimSizesOfTensor(rewriter, op, input);
+  FailureOr<SmallVector<Value, 4>> dimSizesInfo =
+      getDimSizesOfTensor(rewriter, op, input);
+  if (failed(dimSizesInfo))
+    return rewriter.notifyMatchFailure(
+        op, "failed to get dimension sizes of the input");
+
+  auto dimSizes = *dimSizesInfo;
   return getDynamicSliceInternal(rewriter, op, input, normStartIndex,
                                  normEndIndex, step, dim, dimSizes);
 }
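For intuition, here is a minimal standalone sketch of the dimension normalization these helpers build on. It is plain C++ with no MLIR dependencies; this toPositiveDim is a simplified stand-in for the MLIR utility of the same name, reimplemented only for illustration.

#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Simplified stand-in for mlir::toPositiveDim: maps a possibly negative
// index (e.g. -1 for the last axis) into [0, rank).
int64_t toPositiveDim(int64_t dim, int64_t rank) {
  return dim >= 0 ? dim : dim + rank;
}

// Mirrors the toPositiveDims helper added above, with std::vector in place
// of SmallVector.
std::vector<size_t> toPositiveDims(const std::vector<int64_t> &dims,
                                   int64_t rank) {
  std::vector<size_t> posDims;
  posDims.reserve(dims.size());
  for (int64_t d : dims)
    posDims.push_back(static_cast<size_t>(toPositiveDim(d, rank)));
  return posDims;
}

int main() {
  // For a rank-5 tensor, dim -2 refers to axis 3, as exercised by the
  // $from_end test cases later in this patch.
  assert((toPositiveDims({0, -2, 4}, 5) == std::vector<size_t>{0, 3, 4}));
  return 0;
}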
@@ -197,11 +226,11 @@ LogicalResult ConvertAtenOp<AtenSliceTensorOp>::matchAndRewrite(
   auto self = adaptor.self();
   auto selfTy = self.getType().template cast<RankedTensorType>();
   if (!selfTy)
-    return op.emitError("Only ranked tensor types supported in MHLO Rsub");
+    return op.emitError("only ranked tensor types are supported");
   int64_t dim;
   if (!matchPattern(op.dim(), m_TorchConstantInt(&dim)))
     return rewriter.notifyMatchFailure(
-        op, "Only constant dim is currently supported");
+        op, "only constant dim is currently supported");
 
   auto getOptionalVal = [&](Value val) -> llvm::Optional<Value> {
     if (val.getType().isa<Torch::NoneType>()) {
@@ -215,9 +244,14 @@ LogicalResult ConvertAtenOp<AtenSliceTensorOp>::matchAndRewrite(
   llvm::Optional<Value> end = getOptionalVal(adaptor.end());
   llvm::Optional<Value> step = getOptionalVal(adaptor.step());
 
-  Value sliced = getDynamicSlice(rewriter, op, self, start, end, step, dim);
+  FailureOr<Value> sliceInfo =
+      getDynamicSlice(rewriter, op, self, start, end, step, dim);
+  if (failed(sliceInfo))
+    return op.emitError("cannot create a dynamic slice");
+
+  auto slice = *sliceInfo;
   rewriter.replaceOpWithNewOp<mhlo::ConvertOp>(
-      op, getTypeConverter()->convertType(op.getType()), sliced);
+      op, getTypeConverter()->convertType(op.getType()), slice);
 
   return success();
 }
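The slice lowering above now threads failure through FailureOr instead of emitting an error and returning an empty value. A rough standalone model of that control flow, using std::optional as a stand-in for mlir::FailureOr purely for illustration (the inputs here are hypothetical):

#include <iostream>
#include <optional>
#include <string>

// Stand-in for mlir::FailureOr<T>: an empty optional models failure.
template <typename T> using FailureOr = std::optional<T>;

template <typename T> bool failed(const FailureOr<T> &v) {
  return !v.has_value();
}

// Hypothetical fallible producer in the spirit of getDynamicSlice above:
// it either yields a result or reports the failure to its caller.
FailureOr<std::string> getDynamicSlice(bool inputIsRanked) {
  if (!inputIsRanked)
    return std::nullopt; // analogous to rewriter.notifyMatchFailure(...)
  return std::string("sliced-value");
}

int main() {
  FailureOr<std::string> sliceInfo = getDynamicSlice(/*inputIsRanked=*/true);
  if (failed(sliceInfo)) {
    std::cerr << "cannot create a dynamic slice\n";
    return 1;
  }
  auto slice = *sliceInfo; // unwrap only after the failed() check
  std::cout << slice << "\n";
  return 0;
}

The benefit over the old emitOpError-and-return-empty scheme is that the caller decides whether a failure aborts the whole pattern or only this match attempt.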
@@ -316,6 +350,160 @@ bool ConvertAtenViewOp<AtenReshapeOp>::getAtenViewOpSizes(
   return getListConstructElements(adaptor.shape(), dimSizes);
 }
 
+FailureOr<Value> unsqueezeTensor(PatternRewriter &rewriter, Operation *op,
+                                 Value tensor,
+                                 ArrayRef<int64_t> inputUnsqzDims) {
+  // Returns a new tensor with dims of size 1 inserted at the specified
+  // position.
+  //
+  // The position indices (relative to the rank of the returned tensor) are
+  // specified with unsqzDims. Indices must be in order and within the range
+  // of the new tensor rank. Thus, unsqueezing a rank-1 tensor with {0, 2},
+  // {0, 1, 3}, or {0, 1, 2} is valid, but {0, 3} and {2} are not.
+  auto dimSizesInfo = getDimSizesOfTensor(rewriter, op, tensor);
+  if (failed(dimSizesInfo))
+    return rewriter.notifyMatchFailure(
+        op, "failed to get dimension sizes of the input");
+
+  auto dimSizes = *dimSizesInfo;
+  auto rank = dimSizes.size();
+  size_t newRank = rank + inputUnsqzDims.size();
+  auto unsqzDims = toPositiveDims(inputUnsqzDims, newRank);
+  for (size_t k = 0, sz = unsqzDims.size(); k < sz; ++k)
+    if (k > 0 && unsqzDims[k] <= unsqzDims[k - 1])
+      return rewriter.notifyMatchFailure(
+          op, "unsqueeze dimensions must be specified in order");
+
+  auto loc = op->getLoc();
+  auto rankTy = tensor.getType().dyn_cast<RankedTensorType>();
+  auto oldShape = rankTy.getShape();
+  Type intType = rewriter.getIntegerType(kMhloDimSizeBits);
+  auto one = rewriter.create<arith::ConstantOp>(
+      loc, rewriter.getIntegerAttr(intType, 1));
+
+  std::vector<Value> newDimSizes;
+  std::vector<int64_t> newShape;
+  newDimSizes.reserve(newRank);
+  newShape.reserve(newRank);
+  for (size_t k = 0, i = 0, j = 0; k < newRank; ++k) {
+    if (j < unsqzDims.size() && unsqzDims[j] == k) {
+      newDimSizes.push_back(one);
+      newShape.push_back(1);
+      j++;
+    } else {
+      newDimSizes.push_back(dimSizes[i]);
+      newShape.push_back(oldShape[i]);
+      i++;
+    }
+  }
+
+  auto outTy = RankedTensorType::get(newShape, rankTy.getElementType());
+  auto mhloShape = rewriter.create<tensor::FromElementsOp>(loc, newDimSizes);
+  return rewriter.create<mhlo::DynamicReshapeOp>(loc, outTy, tensor, mhloShape)
+      .getResult();
+}
+
+template <>
+LogicalResult ConvertAtenOp<AtenSqueezeOp>::matchAndRewrite(
+    AtenSqueezeOp op, OpAdaptor adaptor,
+    ConversionPatternRewriter &rewriter) const {
+  auto self = adaptor.self();
+  auto selfTy = self.getType().template cast<RankedTensorType>();
+  if (!selfTy)
+    return op.emitError("only ranked tensor types are supported");
+
+  auto rank = selfTy.getRank();
+  if (rank == 0)
+    return rewriter.notifyMatchFailure(
+        op, "the rank of tensor must be greater than 0");
+
+  SmallVector<int64_t, 4> dims;
+  dims.reserve(rank);
+  for (int r = 0; r < rank; ++r) {
+    auto dSize = selfTy.getShape()[r];
+    if (dSize == ShapedType::kDynamicSize)
+      return rewriter.notifyMatchFailure(
+          op, "the size of the dimension being squeezed can't be unknown");
+    if (dSize != 1)
+      dims.push_back(r);
+  }
+
+  auto newDimSizesInfo = getDimSizesOfTensor(rewriter, op, self, dims);
+  if (failed(newDimSizesInfo))
+    return rewriter.notifyMatchFailure(
+        op, "failed to get dimension sizes of the input");
+  auto newDimSizes = *newDimSizesInfo;
+  auto mhloShape =
+      rewriter.create<tensor::FromElementsOp>(op.getLoc(), newDimSizes);
+  rewriter.replaceOpWithNewOp<mhlo::DynamicReshapeOp>(
+      op, getTypeConverter()->convertType(op.getType()), self, mhloShape);
+  return success();
+}
+
+template <>
+LogicalResult ConvertAtenOp<AtenSqueezeDimOp>::matchAndRewrite(
+    AtenSqueezeDimOp op, OpAdaptor adaptor,
+    ConversionPatternRewriter &rewriter) const {
+  auto self = adaptor.self();
+  auto selfTy = self.getType().template cast<RankedTensorType>();
+  if (!selfTy)
+    return op.emitError("only ranked tensor types are supported");
+  int64_t dim;
+  if (!matchPattern(op.dim(), m_TorchConstantInt(&dim)))
+    return rewriter.notifyMatchFailure(
+        op, "only constant dim is currently supported");
+
+  auto rank = selfTy.getRank();
+  if (rank == 0)
+    return rewriter.notifyMatchFailure(
+        op, "the rank of tensor must be greater than 0");
+
+  dim = toPositiveDim(dim, rank);
+  if (selfTy.getShape()[dim] != 1) {
+    if (selfTy.getShape()[dim] == ShapedType::kDynamicSize)
+      return rewriter.notifyMatchFailure(
+          op, "the size of the dimension being squeezed can't be unknown");
+
+    rewriter.replaceOp(op, adaptor.self());
+    return success();
+  }
+
+  SmallVector<int64_t, 4> dims(rank);
+  std::iota(dims.begin(), dims.end(), 0);
+  dims.erase(dims.begin() + dim);
+  auto newDimSizesInfo = getDimSizesOfTensor(rewriter, op, self, dims);
+  if (failed(newDimSizesInfo))
+    return rewriter.notifyMatchFailure(
+        op, "failed to get dimension sizes of the input");
+  auto newDimSizes = *newDimSizesInfo;
+  auto mhloShape =
+      rewriter.create<tensor::FromElementsOp>(op.getLoc(), newDimSizes);
+  rewriter.replaceOpWithNewOp<mhlo::DynamicReshapeOp>(
+      op, getTypeConverter()->convertType(op.getType()), self, mhloShape);
+  return success();
+}
+
+template <>
+LogicalResult ConvertAtenOp<AtenUnsqueezeOp>::matchAndRewrite(
+    AtenUnsqueezeOp op, OpAdaptor adaptor,
+    ConversionPatternRewriter &rewriter) const {
+  auto selfType = adaptor.self().getType().dyn_cast<TensorType>();
+  if (!selfType) {
+    return op.emitError("only tensor types are currently supported");
+  }
+
+  int64_t dim;
+  if (!matchPattern(op.dim(), m_TorchConstantInt(&dim)))
+    return op->emitError("dim must be a Scalar constant");
+
+  auto unsqzTensorInfo = unsqueezeTensor(rewriter, op, adaptor.self(), {dim});
+  if (failed(unsqzTensorInfo))
+    return rewriter.notifyMatchFailure(op,
+                                       "failed to create unsqueezed tensor");
+
+  rewriter.replaceOp(op, *unsqzTensorInfo);
+  return success();
+}
 } // namespace
 
 void mlir::torch::torch_to_mhlo::populateViewLikeOpPatternsAndLegality(
@@ -327,6 +515,9 @@ void mlir::torch::torch_to_mhlo::populateViewLikeOpPatternsAndLegality(
   target.addIllegalOp<AtenOp>();                                             \
   patterns.add<ConvertAtenOp<AtenOp>>(typeConverter, context);
   INSERT_ATENOP_PATTERN(AtenSliceTensorOp);
+  INSERT_ATENOP_PATTERN(AtenSqueezeOp);
+  INSERT_ATENOP_PATTERN(AtenSqueezeDimOp);
+  INSERT_ATENOP_PATTERN(AtenUnsqueezeOp);
 #undef INSERT_ATENOP_PATTERN
 
 #define INSERT_VIEW_OP_PATTERN(AtenOp)                                       \
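The core of unsqueezeTensor above is the loop that interleaves constant-1 dimensions into the input shape. A standalone sketch of just that shape computation (plain C++; -1 stands in for ShapedType::kDynamicSize):

#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Walk the output rank: emit 1 wherever an unsqueeze position is
// requested, otherwise copy the next input dimension.
std::vector<int64_t> unsqueezeShape(const std::vector<int64_t> &oldShape,
                                    const std::vector<size_t> &unsqzDims) {
  size_t newRank = oldShape.size() + unsqzDims.size();
  std::vector<int64_t> newShape;
  newShape.reserve(newRank);
  for (size_t k = 0, i = 0, j = 0; k < newRank; ++k) {
    if (j < unsqzDims.size() && unsqzDims[j] == k) {
      newShape.push_back(1); // inserted size-1 dimension
      j++;
    } else {
      newShape.push_back(oldShape[i]); // carried-over input dimension
      i++;
    }
  }
  return newShape;
}

int main() {
  // Unsqueezing a rank-1 tensor at {0, 2} is valid: [5] -> [1, 5, 1].
  assert((unsqueezeShape({5}, {0, 2}) == std::vector<int64_t>{1, 5, 1}));
  // Dim 0 of a rank-4 dynamic tensor: [?,?,?,?] -> [1,?,?,?,?].
  assert((unsqueezeShape({-1, -1, -1, -1}, {0}) ==
          std::vector<int64_t>{1, -1, -1, -1, -1}));
  return 0;
}

The tests below exercise exactly these cases: the dynamic-shaped unsqueeze tests feed the tensor.dim results plus an i64 constant 1 into tensor.from_elements, which becomes the runtime shape operand of mhlo.dynamic_reshape.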
@@ -414,3 +414,168 @@ func.func @torch.aten.view$to_rank0(%arg0: !torch.vtensor<[1],f32>) -> !torch.vt
   %1 = torch.aten.view %arg0, %0 : !torch.vtensor<[1],f32>, !torch.list<int> -> !torch.vtensor<[],f32>
   return %1 : !torch.vtensor<[],f32>
 }
+// CHECK-LABEL: func.func @torch.aten.squeeze.dim$0$static(
+// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[2,1,2,1,2],f32>) -> !torch.vtensor<[2,1,2,1,2],f32> {
+// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[2,1,2,1,2],f32> -> tensor<2x1x2x1x2xf32>
+// CHECK: %[[INT0:.*]] = torch.constant.int 0
+// CHECK: %[[T1:.*]] = torch_c.from_builtin_tensor %[[T0]] : tensor<2x1x2x1x2xf32> -> !torch.vtensor<[2,1,2,1,2],f32>
+// CHECK: return %[[T1]] : !torch.vtensor<[2,1,2,1,2],f32>
+func.func @torch.aten.squeeze.dim$0$static(%arg0: !torch.vtensor<[2,1,2,1,2],f32>) -> !torch.vtensor<[2,1,2,1,2],f32> {
+  %int0 = torch.constant.int 0
+  %0 = torch.aten.squeeze.dim %arg0, %int0 : !torch.vtensor<[2,1,2,1,2],f32>, !torch.int -> !torch.vtensor<[2,1,2,1,2],f32>
+  return %0 : !torch.vtensor<[2,1,2,1,2],f32>
+}
+
+// CHECK-LABEL: func.func @torch.aten.squeeze.dim$1(
+// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,1,?,1,?],f32>) -> !torch.vtensor<[?,?,1,?],f32> {
+// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,1,?,1,?],f32> -> tensor<?x1x?x1x?xf32>
+// CHECK: %[[INT1:.*]] = torch.constant.int 1
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[T1:.*]] = tensor.dim %[[T0]], %[[C0]] : tensor<?x1x?x1x?xf32>
+// CHECK: %[[T2:.*]] = arith.index_cast %[[T1]] : index to i64
+// CHECK: %[[C2:.*]] = arith.constant 2 : index
+// CHECK: %[[T3:.*]] = tensor.dim %[[T0]], %[[C2]] : tensor<?x1x?x1x?xf32>
+// CHECK: %[[T4:.*]] = arith.index_cast %[[T3]] : index to i64
+// CHECK: %[[C3:.*]] = arith.constant 3 : index
+// CHECK: %[[T5:.*]] = tensor.dim %[[T0]], %[[C3]] : tensor<?x1x?x1x?xf32>
+// CHECK: %[[T6:.*]] = arith.index_cast %[[T5]] : index to i64
+// CHECK: %[[C4:.*]] = arith.constant 4 : index
+// CHECK: %[[T7:.*]] = tensor.dim %[[T0]], %[[C4]] : tensor<?x1x?x1x?xf32>
+// CHECK: %[[T8:.*]] = arith.index_cast %[[T7]] : index to i64
+// CHECK: %[[T9:.*]] = tensor.from_elements %[[T2]], %[[T4]], %[[T6]], %[[T8]] : tensor<4xi64>
+// CHECK: %[[T10:.*]] = "mhlo.dynamic_reshape"(%[[T0]], %[[T9]]) : (tensor<?x1x?x1x?xf32>, tensor<4xi64>) -> tensor<?x?x1x?xf32>
+// CHECK: %[[T11:.*]] = torch_c.from_builtin_tensor %[[T10]] : tensor<?x?x1x?xf32> -> !torch.vtensor<[?,?,1,?],f32>
+// CHECK: return %[[T11]] : !torch.vtensor<[?,?,1,?],f32>
+func.func @torch.aten.squeeze.dim$1(%arg0: !torch.vtensor<[?,1,?,1,?],f32>) -> !torch.vtensor<[?,?,1,?],f32> {
+  %int1 = torch.constant.int 1
+  %0 = torch.aten.squeeze.dim %arg0, %int1 : !torch.vtensor<[?,1,?,1,?],f32>, !torch.int -> !torch.vtensor<[?,?,1,?],f32>
+  return %0 : !torch.vtensor<[?,?,1,?],f32>
+}
+
+// CHECK-LABEL: func.func @torch.aten.squeeze.dim$from_end(
+// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,1,?,1,?],f32>) -> !torch.vtensor<[?,1,?,?],f32> {
+// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,1,?,1,?],f32> -> tensor<?x1x?x1x?xf32>
+// CHECK: %[[INT:.*]]-2 = torch.constant.int -2
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[T1:.*]] = tensor.dim %[[T0]], %[[C0]] : tensor<?x1x?x1x?xf32>
+// CHECK: %[[T2:.*]] = arith.index_cast %[[T1]] : index to i64
+// CHECK: %[[C1:.*]] = arith.constant 1 : index
+// CHECK: %[[T3:.*]] = tensor.dim %[[T0]], %[[C1]] : tensor<?x1x?x1x?xf32>
+// CHECK: %[[T4:.*]] = arith.index_cast %[[T3]] : index to i64
+// CHECK: %[[C2:.*]] = arith.constant 2 : index
+// CHECK: %[[T5:.*]] = tensor.dim %[[T0]], %[[C2]] : tensor<?x1x?x1x?xf32>
+// CHECK: %[[T6:.*]] = arith.index_cast %[[T5]] : index to i64
+// CHECK: %[[C4:.*]] = arith.constant 4 : index
+// CHECK: %[[T7:.*]] = tensor.dim %[[T0]], %[[C4]] : tensor<?x1x?x1x?xf32>
+// CHECK: %[[T8:.*]] = arith.index_cast %[[T7]] : index to i64
+// CHECK: %[[T9:.*]] = tensor.from_elements %[[T2]], %[[T4]], %[[T6]], %[[T8]] : tensor<4xi64>
+// CHECK: %[[T10:.*]] = "mhlo.dynamic_reshape"(%[[T0]], %[[T9]]) : (tensor<?x1x?x1x?xf32>, tensor<4xi64>) -> tensor<?x1x?x?xf32>
+// CHECK: %[[T11:.*]] = torch_c.from_builtin_tensor %[[T10]] : tensor<?x1x?x?xf32> -> !torch.vtensor<[?,1,?,?],f32>
+// CHECK: return %[[T11]] : !torch.vtensor<[?,1,?,?],f32>
+func.func @torch.aten.squeeze.dim$from_end(%arg0: !torch.vtensor<[?,1,?,1,?],f32>) -> !torch.vtensor<[?,1,?,?],f32> {
+  %int-2 = torch.constant.int -2
+  %0 = torch.aten.squeeze.dim %arg0, %int-2 : !torch.vtensor<[?,1,?,1,?],f32>, !torch.int -> !torch.vtensor<[?,1,?,?],f32>
+  return %0 : !torch.vtensor<[?,1,?,?],f32>
+}
+
+// CHECK-LABEL: func.func @torch.aten.squeeze$static(
+// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[2,1,2,1,2],f32>) -> !torch.vtensor<[2,2,2],f32> {
+// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[2,1,2,1,2],f32> -> tensor<2x1x2x1x2xf32>
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[T1:.*]] = tensor.dim %[[T0]], %[[C0]] : tensor<2x1x2x1x2xf32>
+// CHECK: %[[T2:.*]] = arith.index_cast %[[T1]] : index to i64
+// CHECK: %[[C2:.*]] = arith.constant 2 : index
+// CHECK: %[[T3:.*]] = tensor.dim %[[T0]], %[[C2]] : tensor<2x1x2x1x2xf32>
+// CHECK: %[[T4:.*]] = arith.index_cast %[[T3]] : index to i64
+// CHECK: %[[C4:.*]] = arith.constant 4 : index
+// CHECK: %[[T5:.*]] = tensor.dim %[[T0]], %[[C4]] : tensor<2x1x2x1x2xf32>
+// CHECK: %[[T6:.*]] = arith.index_cast %[[T5]] : index to i64
+// CHECK: %[[T7:.*]] = tensor.from_elements %[[T2]], %[[T4]], %[[T6]] : tensor<3xi64>
+// CHECK: %[[T8:.*]] = "mhlo.dynamic_reshape"(%[[T0]], %[[T7]]) : (tensor<2x1x2x1x2xf32>, tensor<3xi64>) -> tensor<2x2x2xf32>
+// CHECK: %[[T9:.*]] = torch_c.from_builtin_tensor %[[T8]] : tensor<2x2x2xf32> -> !torch.vtensor<[2,2,2],f32>
+// CHECK: return %[[T9]] : !torch.vtensor<[2,2,2],f32>
+func.func @torch.aten.squeeze$static(%arg0: !torch.vtensor<[2,1,2,1,2],f32>) -> !torch.vtensor<[2,2,2],f32> {
+  %0 = torch.aten.squeeze %arg0 : !torch.vtensor<[2,1,2,1,2],f32> -> !torch.vtensor<[2,2,2],f32>
+  return %0 : !torch.vtensor<[2,2,2],f32>
+}
+
+// CHECK-LABEL: func.func @torch.aten.unsqueeze$dim$0(
+// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[1,?,?,?,?],f32> {
+// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?,?,?],f32> -> tensor<?x?x?x?xf32>
+// CHECK: %[[INT0:.*]] = torch.constant.int 0
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[T1:.*]] = tensor.dim %[[T0]], %[[C0]] : tensor<?x?x?x?xf32>
+// CHECK: %[[T2:.*]] = arith.index_cast %[[T1]] : index to i64
+// CHECK: %[[C1:.*]] = arith.constant 1 : index
+// CHECK: %[[T3:.*]] = tensor.dim %[[T0]], %[[C1]] : tensor<?x?x?x?xf32>
+// CHECK: %[[T4:.*]] = arith.index_cast %[[T3]] : index to i64
+// CHECK: %[[C2:.*]] = arith.constant 2 : index
+// CHECK: %[[T5:.*]] = tensor.dim %[[T0]], %[[C2]] : tensor<?x?x?x?xf32>
+// CHECK: %[[T6:.*]] = arith.index_cast %[[T5]] : index to i64
+// CHECK: %[[C3:.*]] = arith.constant 3 : index
+// CHECK: %[[T7:.*]] = tensor.dim %[[T0]], %[[C3]] : tensor<?x?x?x?xf32>
+// CHECK: %[[T8:.*]] = arith.index_cast %[[T7]] : index to i64
+// CHECK: %[[C1_I64:.*]] = arith.constant 1 : i64
+// CHECK: %[[T9:.*]] = tensor.from_elements %[[C1_I64]], %[[T2]], %[[T4]], %[[T6]], %[[T8]] : tensor<5xi64>
+// CHECK: %[[T10:.*]] = "mhlo.dynamic_reshape"(%[[T0]], %[[T9]]) : (tensor<?x?x?x?xf32>, tensor<5xi64>) -> tensor<1x?x?x?x?xf32>
+// CHECK: %[[T11:.*]] = torch_c.from_builtin_tensor %[[T10]] : tensor<1x?x?x?x?xf32> -> !torch.vtensor<[1,?,?,?,?],f32>
+// CHECK: return %[[T11]] : !torch.vtensor<[1,?,?,?,?],f32>
+func.func @torch.aten.unsqueeze$dim$0(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[1,?,?,?,?],f32> {
+  %int0 = torch.constant.int 0
+  %0 = torch.aten.unsqueeze %arg0, %int0 : !torch.vtensor<[?,?,?,?],f32>, !torch.int -> !torch.vtensor<[1,?,?,?,?],f32>
+  return %0 : !torch.vtensor<[1,?,?,?,?],f32>
+}
+
+// CHECK-LABEL: func.func @torch.aten.unsqueeze$dim$1(
+// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,1,?,?,?],f32> {
+// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?,?,?],f32> -> tensor<?x?x?x?xf32>
+// CHECK: %[[INT1:.*]] = torch.constant.int 1
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[T1:.*]] = tensor.dim %[[T0]], %[[C0]] : tensor<?x?x?x?xf32>
+// CHECK: %[[T2:.*]] = arith.index_cast %[[T1]] : index to i64
+// CHECK: %[[C1:.*]] = arith.constant 1 : index
+// CHECK: %[[T3:.*]] = tensor.dim %[[T0]], %[[C1]] : tensor<?x?x?x?xf32>
+// CHECK: %[[T4:.*]] = arith.index_cast %[[T3]] : index to i64
+// CHECK: %[[C2:.*]] = arith.constant 2 : index
+// CHECK: %[[T5:.*]] = tensor.dim %[[T0]], %[[C2]] : tensor<?x?x?x?xf32>
+// CHECK: %[[T6:.*]] = arith.index_cast %[[T5]] : index to i64
+// CHECK: %[[C3:.*]] = arith.constant 3 : index
+// CHECK: %[[T7:.*]] = tensor.dim %[[T0]], %[[C3]] : tensor<?x?x?x?xf32>
+// CHECK: %[[T8:.*]] = arith.index_cast %[[T7]] : index to i64
+// CHECK: %[[C1_I64:.*]] = arith.constant 1 : i64
+// CHECK: %[[T9:.*]] = tensor.from_elements %[[T2]], %[[C1_I64]], %[[T4]], %[[T6]], %[[T8]] : tensor<5xi64>
+// CHECK: %[[T10:.*]] = "mhlo.dynamic_reshape"(%[[T0]], %[[T9]]) : (tensor<?x?x?x?xf32>, tensor<5xi64>) -> tensor<?x1x?x?x?xf32>
+// CHECK: %[[T11:.*]] = torch_c.from_builtin_tensor %[[T10]] : tensor<?x1x?x?x?xf32> -> !torch.vtensor<[?,1,?,?,?],f32>
+// CHECK: return %[[T11]] : !torch.vtensor<[?,1,?,?,?],f32>
+func.func @torch.aten.unsqueeze$dim$1(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,1,?,?,?],f32> {
+  %int1 = torch.constant.int 1
+  %0 = torch.aten.unsqueeze %arg0, %int1 : !torch.vtensor<[?,?,?,?],f32>, !torch.int -> !torch.vtensor<[?,1,?,?,?],f32>
+  return %0 : !torch.vtensor<[?,1,?,?,?],f32>
+}
+
+// CHECK-LABEL: func.func @torch.aten.unsqueeze$from_end(
+// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?,1,?],f32> {
+// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?,?,?],f32> -> tensor<?x?x?x?xf32>
+// CHECK: %[[INT:.*]]-2 = torch.constant.int -2
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[T1:.*]] = tensor.dim %[[T0]], %[[C0]] : tensor<?x?x?x?xf32>
+// CHECK: %[[T2:.*]] = arith.index_cast %[[T1]] : index to i64
+// CHECK: %[[C1:.*]] = arith.constant 1 : index
+// CHECK: %[[T3:.*]] = tensor.dim %[[T0]], %[[C1]] : tensor<?x?x?x?xf32>
+// CHECK: %[[T4:.*]] = arith.index_cast %[[T3]] : index to i64
+// CHECK: %[[C2:.*]] = arith.constant 2 : index
+// CHECK: %[[T5:.*]] = tensor.dim %[[T0]], %[[C2]] : tensor<?x?x?x?xf32>
+// CHECK: %[[T6:.*]] = arith.index_cast %[[T5]] : index to i64
+// CHECK: %[[C3:.*]] = arith.constant 3 : index
+// CHECK: %[[T7:.*]] = tensor.dim %[[T0]], %[[C3]] : tensor<?x?x?x?xf32>
+// CHECK: %[[T8:.*]] = arith.index_cast %[[T7]] : index to i64
+// CHECK: %[[C1_I64:.*]] = arith.constant 1 : i64
+// CHECK: %[[T9:.*]] = tensor.from_elements %[[T2]], %[[T4]], %[[T6]], %[[C1_I64]], %[[T8]] : tensor<5xi64>
+// CHECK: %[[T10:.*]] = "mhlo.dynamic_reshape"(%[[T0]], %[[T9]]) : (tensor<?x?x?x?xf32>, tensor<5xi64>) -> tensor<?x?x?x1x?xf32>
+// CHECK: %[[T11:.*]] = torch_c.from_builtin_tensor %[[T10]] : tensor<?x?x?x1x?xf32> -> !torch.vtensor<[?,?,?,1,?],f32>
+// CHECK: return %[[T11]] : !torch.vtensor<[?,?,?,1,?],f32>
+func.func @torch.aten.unsqueeze$from_end(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?,1,?],f32> {
+  %int-2 = torch.constant.int -2
+  %0 = torch.aten.unsqueeze %arg0, %int-2 : !torch.vtensor<[?,?,?,?],f32>, !torch.int -> !torch.vtensor<[?,?,?,1,?],f32>
+  return %0 : !torch.vtensor<[?,?,?,1,?],f32>
+}
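For reference, the AtenSqueezeOp pattern tested above selects the surviving axes by filtering out static size-1 dimensions and bailing out on dynamic ones. A standalone sketch of that selection (plain C++; -1 again models ShapedType::kDynamicSize):

#include <cassert>
#include <cstdint>
#include <vector>

// Keep every axis whose static size is not 1. A dynamic axis makes the
// lowering give up, since it may or may not be squeezable at runtime.
bool collectKeptDims(const std::vector<int64_t> &shape,
                     std::vector<int64_t> &dims) {
  for (int64_t r = 0, rank = shape.size(); r < rank; ++r) {
    if (shape[r] == -1)
      return false; // size of a candidate dimension is unknown
    if (shape[r] != 1)
      dims.push_back(r);
  }
  return true;
}

int main() {
  std::vector<int64_t> dims;
  // [2,1,2,1,2] keeps axes 0, 2, 4, giving result shape [2,2,2], which is
  // what the torch.aten.squeeze$static test checks.
  assert(collectKeptDims({2, 1, 2, 1, 2}, dims));
  assert((dims == std::vector<int64_t>{0, 2, 4}));
  return 0;
}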