mirror of https://github.com/llvm/torch-mlir

[Torch] enhance fold of aten.slice.Tensor (#3557)

Enhance the folder so that it can fold a slice with any static shape;
previously only slices whose non-sliced output dimensions were all
unit-sized could be folded.

branch: pull/3558/head
parent: 78846425e2
commit: 21ad890009
lib/Dialect/Torch/IR/TorchOps.cpp

@@ -3625,12 +3625,11 @@ OpFoldResult AtenSliceTensorOp::fold(FoldAdaptor adaptor) {
     return DenseElementsAttr::get(outType.toBuiltinTensor(),
                                   input.getSplatValue<Attribute>());
 
-  int count = 1;
+  int64_t count = 1;
   for (auto dim : outType.getSizes())
     count = count * dim;
 
   if (count == 0)
-    return {};
+    return nullptr;
 
   if (!dim)
     return nullptr;
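Two small cleanups in this first hunk: count is widened from int to
int64_t, presumably because the element count of a large static shape can
exceed 32 bits, and the "no fold" result is now spelled nullptr throughout
(a default-constructed OpFoldResult and nullptr are both null, so {} and
nullptr mean the same thing here). A minimal standalone sketch of the
overflow concern, using a hypothetical 65536x65536 shape chosen only for
illustration:

// Standalone sketch, not torch-mlir code: why a 64-bit count matters.
#include <climits>
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  std::vector<int64_t> sizes = {65536, 65536}; // hypothetical static shape
  int64_t count = 1;
  for (int64_t dim : sizes)
    count = count * dim; // 2^32 elements; overflows a 32-bit int
  std::cout << (count > INT_MAX) << "\n"; // prints 1
  return 0;
}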
@@ -3638,29 +3637,41 @@ OpFoldResult AtenSliceTensorOp::fold(FoldAdaptor adaptor) {
   if (dimInt < 0)
     dimInt += inType.getSizes().size();
 
-  bool unaryNonDim = true;
-  for (int i = 0, s = outType.getSizes().size(); i < s; ++i)
-    unaryNonDim &= outType.getSizes()[i] == 1 || i == dimInt;
-
-  // Fold the slice if the output tensor is relatively small, currently
-  // coded to 16:
-  if (input && start && step && dim && count < 16 && unaryNonDim &&
-      count < 16) {
-    int64_t inCount = input.getNumElements();
+  constexpr int64_t kMaxFold = 16;
+  if (input && start && step && dim && count <= kMaxFold) {
     int64_t begin = start.getValue().getSExtValue();
+    int64_t limit = end.getValue().getSExtValue();
     int64_t stride = step.getValue().getSExtValue();
     if (stride < 1)
-      return {};
-    int64_t limit = end.getValue().getSExtValue();
-    begin = begin < 0 ? begin + inCount : begin;
-    limit = limit < 0 ? limit + inCount : limit;
-    limit = limit < 0 ? inType.getSizes()[dimInt] : limit;
+      return nullptr;
+    begin = begin < 0 ? begin + inType.getSizes()[dimInt] : begin;
+    limit = limit < 0 ? limit + inType.getSizes()[dimInt] : limit;
+    limit = std::min(limit, inType.getSizes()[dimInt]);
 
-    llvm::SmallVector<Attribute> values;
-    for (int i = begin; i < limit; i += stride)
-      values.push_back(input.getValues<Attribute>()[i]);
+    int64_t inputRank = inType.getSizes().size();
+    llvm::SmallVector<int64_t> inputStrides(inputRank, 1);
+    for (int64_t i = inputRank - 2; i >= 0; i--) {
+      inputStrides[i] = inputStrides[i + 1] * inType.getSizes()[i + 1];
+    }
+
+    llvm::SmallVector<Attribute> values;
+    values.reserve(count);
+    auto recursiveIter = [&](auto &self, int64_t currDim, int64_t currOffset) {
+      if (currDim >= inputRank)
+        return;
+      size_t _begin = (currDim == dimInt) ? begin : 0;
+      size_t _limit = (currDim == dimInt) ? limit : inType.getSizes()[currDim];
+      size_t _stride = (currDim == dimInt) ? stride : 1;
+      for (size_t i = _begin; i < _limit; i += _stride) {
+        if (currDim == inputRank - 1) {
+          values.push_back(input.getValues<Attribute>()[currOffset + i]);
+        }
+        self(self, currDim + 1, currOffset + inputStrides[currDim] * i);
+      }
+    };
+    recursiveIter(recursiveIter, 0, 0);
     return DenseElementsAttr::get(outType.toBuiltinTensor(), values);
   }
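The heart of the change: the old fold required every non-sliced output
dimension to be unit-sized (unaryNonDim) so that it could index the flat
buffer directly; the new fold computes row-major strides for the input and
walks all dimensions recursively, restricting only the sliced dimension to
[begin, limit) with the given stride. Here is a minimal standalone sketch
of that indexing scheme in plain C++ (no torch-mlir APIs), run on the 4x2
input from the new non-contiguous test below:

// Standalone sketch, not torch-mlir code: stride-based gather for a slice
// along dim 1 of a 4x2 row-major buffer (start=1, end=2, step=1).
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  std::vector<int64_t> sizes = {4, 2};
  std::vector<int64_t> data = {28, 28, 14, 14, 7, 7, 4, 4};
  int64_t dimInt = 1, begin = 1, limit = 2, stride = 1;

  // Row-major strides: the last dim has stride 1; each earlier dim's
  // stride is the product of all later dim sizes.
  int64_t rank = sizes.size();
  std::vector<int64_t> strides(rank, 1);
  for (int64_t i = rank - 2; i >= 0; i--)
    strides[i] = strides[i + 1] * sizes[i + 1];

  std::vector<int64_t> values;
  // Recurse over all dims; only the sliced dim is restricted.
  auto iter = [&](auto &self, int64_t currDim, int64_t currOffset) -> void {
    if (currDim >= rank)
      return;
    int64_t b = (currDim == dimInt) ? begin : 0;
    int64_t l = (currDim == dimInt) ? limit : sizes[currDim];
    int64_t s = (currDim == dimInt) ? stride : 1;
    for (int64_t i = b; i < l; i += s) {
      if (currDim == rank - 1)
        values.push_back(data[currOffset + i]);
      self(self, currDim + 1, currOffset + strides[currDim] * i);
    }
  };
  iter(iter, 0, 0);

  for (int64_t v : values)
    std::cout << v << " "; // prints: 28 14 7 4
  std::cout << "\n";
  return 0;
}

The recursion visits flat offsets 1, 3, 5, 7, matching the
dense<[[28], [14], [7], [4]]> literal the new FileCheck test expects.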
test/Dialect/Torch/canonicalize.mlir

@@ -2139,15 +2139,15 @@ func.func @torch.aten.broadcast_to$fold_splat() -> !torch.vtensor<[3,4,2],f32> {
 
 // -----
 
-// CHECK-LABEL: @torch.aten.slice.tensor$fold_full_domain_slice
+// CHECK-LABEL: @torch.aten.slice.tensor$not_fold_slice
 // CHECK-SAME: %[[ARG0:.+]]: !torch.vtensor<[4],f32>
-// CHECK: return %[[ARG0]] : !torch.vtensor<[4],f32>
-func.func @torch.aten.slice.tensor$fold_full_domain_slice(%arg0: !torch.vtensor<[4],f32>) -> !torch.vtensor<[4],f32> {
+// CHECK: torch.aten.slice.Tensor
+func.func @torch.aten.slice.tensor$not_fold_slice(%arg0: !torch.vtensor<[4],f32>) -> !torch.vtensor<[3],f32> {
   %int1 = torch.constant.int 1
   %int-1 = torch.constant.int -1
   %int0 = torch.constant.int 0
-  %0 = torch.aten.slice.Tensor %arg0, %int0, %int0, %int-1, %int1 : !torch.vtensor<[4], f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4], f32>
-  return %0 : !torch.vtensor<[4],f32>
+  %0 = torch.aten.slice.Tensor %arg0, %int0, %int0, %int-1, %int1 : !torch.vtensor<[4], f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[3], f32>
+  return %0 : !torch.vtensor<[3],f32>
 }
 
 // CHECK-LABEL: @torch.aten.slice.tensor$fold_full_slice
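This test is renamed and retyped because start=0, end=-1, step=1 on a
size-4 dimension is not a full-domain slice: the negative end normalizes
to 3, so the result type is [3] and the op must not fold away to its
input. A standalone sketch of the normalization arithmetic, mirroring the
begin/limit/stride handling in the C++ hunk above:

// Standalone sketch, not torch-mlir code: normalizing start=0, end=-1,
// step=1 against a dimension of size 4 yields 3 elements, not 4.
#include <algorithm>
#include <cstdint>
#include <iostream>

int main() {
  int64_t size = 4, begin = 0, limit = -1, stride = 1;
  begin = begin < 0 ? begin + size : begin; // 0
  limit = limit < 0 ? limit + size : limit; // 3
  limit = std::min(limit, size);            // 3
  int64_t n = 0;
  for (int64_t i = begin; i < limit; i += stride)
    ++n;
  std::cout << n << "\n"; // prints 3, matching !torch.vtensor<[3],f32>
  return 0;
}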
@@ -2209,7 +2209,10 @@ func.func @torch.aten.slice.tensor$fold_small() -> (!torch.vtensor<[2],si32>) {
 }
 
 // -----
-
+// CHECK-LABEL: func.func @torch.aten.slice.tensor$fold_dim_0() -> (!torch.vtensor<[1,1],f32>, !torch.vtensor<[1,1],f32>) {
+// CHECK: %[[CST:.+]] = torch.vtensor.literal(dense<1.600000e+01> : tensor<1x1xf32>) : !torch.vtensor<[1,1],f32>
+// CHECK: %[[CST0:.+]] = torch.vtensor.literal(dense<6.400000e+01> : tensor<1x1xf32>) : !torch.vtensor<[1,1],f32>
+// CHECK: return %[[CST]], %[[CST0]]
 func.func @torch.aten.slice.tensor$fold_dim_0() -> (!torch.vtensor<[1, 1],f32>, !torch.vtensor<[1, 1],f32>) {
   %tensor = torch.vtensor.literal(dense<[[2.0],[4.0],[8.0],[16.0],[32.0],[64.0],[128.0],[256.0],[512.0],[1024.0]]> : tensor<10x1xf32>) : !torch.vtensor<[10, 1],f32>
   %int0 = torch.constant.int 0
@@ -2224,6 +2227,18 @@ func.func @torch.aten.slice.tensor$fold_dim_0() -> (!torch.vtensor<[1, 1],f32>,
   return %0, %1 : !torch.vtensor<[1, 1],f32>, !torch.vtensor<[1, 1], f32>
 }
 
+// -----
+// CHECK-LABEL: func.func @torch.aten.slice.tensor$fold_dim_0_non_contiguous() -> !torch.vtensor<[4,1],si64> {
+// CHECK{LITERAL}: %0 = torch.vtensor.literal(dense<[[28], [14], [7], [4]]> : tensor<4x1xsi64>) : !torch.vtensor<[4,1],si64>
+// CHECK: return %0
+func.func @torch.aten.slice.tensor$fold_dim_0_non_contiguous() -> (!torch.vtensor<[4,1],si64>) {
+  %int1 = torch.constant.int 1
+  %int2 = torch.constant.int 2
+  %0 = torch.vtensor.literal(dense<[[28, 28], [14, 14], [7, 7], [4, 4]]> : tensor<4x2xsi64>) : !torch.vtensor<[4,2],si64>
+  %1 = torch.aten.slice.Tensor %0, %int1, %int1, %int2, %int1 : !torch.vtensor<[4,2],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,1],si64>
+  return %1 : !torch.vtensor<[4,1],si64>
+}
+
 // -----
 
 // CHECK-LABEL: func.func @torch.aten.rsub.Scalar$canonicalize_literal_0d() -> !torch.vtensor<[],si64> {