[torch] Additional folders for shape computations (#2972)

A handful of operations are commonly used in shape calculations (slice,
concat, broadcast). Added these additional folders to better propagate
simple shape computations.
pull/2978/head
Rob Suderman 2024-03-04 11:46:49 -08:00 committed by GitHub
parent 09875fabd1
commit a86e89ecb5
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 160 additions and 44 deletions

View File

@ -2899,12 +2899,59 @@ OpFoldResult AtenSubIntOp::fold(FoldAdaptor adaptor) {
//===----------------------------------------------------------------------===//
OpFoldResult AtenCatOp::fold(FoldAdaptor adaptor) {
// We set a maximum folding size of 16. This is a reasonable upper limit
// for shape computations.
constexpr int64_t kMaxFoldSize = 16;
auto list = getOperand(0).getDefiningOp<PrimListConstructOp>();
if (!list || list.getElements().size() != 1)
if (!list)
return nullptr;
if (list.getElements()[0].getType() != getResult().getType())
auto elements = list.getElements();
if (elements.size() == 1 && elements[0].getType() == getResult().getType())
return list.getElements()[0];
auto resultTy = dyn_cast<ValueTensorType>(getType());
if (!resultTy || !resultTy.hasSizes() || !resultTy.hasDtype())
return nullptr;
return list.getElements()[0];
auto bResultTy = resultTy.toBuiltinTensor();
if (!bResultTy.hasStaticShape() || bResultTy.getNumElements() > kMaxFoldSize)
return nullptr;
auto dimAttr = dyn_cast_or_null<IntegerAttr>(adaptor.getDim());
if (!dimAttr)
return nullptr;
auto dim = dimAttr.getValue().getSExtValue();
dim += dim < 0 ? bResultTy.getRank() : 0;
for (int i = 0, s = bResultTy.getRank(); i < s; ++i) {
if (i == dim)
continue;
if (bResultTy.getDimSize(i) != 1)
return nullptr;
}
llvm::SmallVector<Attribute> values;
for (auto operand : list.getOperands()) {
DenseElementsAttr dattr;
if (!matchPattern(operand, m_Constant(&dattr)))
return nullptr;
auto oty = dyn_cast<RankedTensorType>(dattr.getType());
if (!oty)
return nullptr;
if (dattr.isSplat()) {
for (int i = 0, s = oty.getDimSize(dim); i < s; ++i)
values.push_back(dattr.getSplatValue<Attribute>());
} else {
auto evals = dattr.getValues<Attribute>();
for (int i = 0, s = oty.getDimSize(dim); i < s; ++i)
values.push_back(evals[i]);
}
}
return DenseElementsAttr::get(bResultTy.clone(resultTy.getDtype()), values);
}
void AtenCatOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
@ -2947,19 +2994,32 @@ void AtenCatOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
OpFoldResult AtenBroadcastToOp::fold(FoldAdaptor adaptor) {
auto inType = getOperand(0).getType().dyn_cast<BaseTensorType>();
auto outType = getResult().getType().dyn_cast<BaseTensorType>();
if (inType != outType)
if (!inType || !outType || !inType.hasSizes() || !outType.hasSizes() ||
!outType.hasDtype())
return nullptr;
if (!inType || !outType || !inType.hasSizes() || !outType.hasSizes())
if (!inType.areAllSizesKnown() || !outType.areAllSizesKnown())
return nullptr;
if (inType.getSizes().size() != outType.getSizes().size() ||
(!isAssumingStrictSymbolicShapes((*this)->getBlock()) &&
(!inType.areAllSizesKnown() || !outType.areAllSizesKnown())))
return nullptr;
for (size_t i = 0; i < inType.getSizes().size(); ++i) {
if (inType.getSizes()[i] != outType.getSizes()[i])
return nullptr;
auto inSizes = inType.getSizes();
auto outSizes = outType.getSizes();
if (inSizes.size() == outSizes.size()) {
bool sameSizes = true;
for (int i = 0, s = inSizes.size(); i < s; ++i)
sameSizes &= inSizes[i] == outSizes[i];
if (sameSizes)
return getOperand(0);
}
return getOperand(0);
auto selfAttr = dyn_cast_or_null<DenseElementsAttr>(adaptor.getSelf());
if (!selfAttr)
return nullptr;
if (!selfAttr.isSplat())
return nullptr;
auto attrty = RankedTensorType::get(outType.getSizes(), outType.getDtype());
return DenseElementsAttr::get(attrty, selfAttr.getSplatValue<Attribute>());
}
//===----------------------------------------------------------------------===//
@ -2995,23 +3055,44 @@ OpFoldResult AtenSliceTensorOp::fold(FoldAdaptor adaptor) {
outType.toBuiltinTensor().clone(inType.getDtype()),
input.getSplatValue<Attribute>());
// If the output is a single value we can index into a constant input and grab
// that single value:
if (input && start && dim &&
llvm::all_of(outType.getSizes(), [](int64_t dim) { return dim == 1; })) {
bool unaryNonDim = true;
int64_t dimInt = dim.getValue().getSExtValue();
for (int i = 0, s = inType.getSizes().size(); i < s; ++i) {
unaryNonDim &= inType.getSizes()[i] == 1 || i == dimInt;
}
if (unaryNonDim) {
int64_t idx = start.getValue().getSExtValue();
if (idx < 0)
idx += input.getNumElements();
Attribute value = input.getValues<Attribute>()[idx];
return DenseElementsAttr::get(
outType.toBuiltinTensor().clone(inType.getDtype()), value);
}
int count = 1;
for (auto dim : outType.getSizes())
count = count * dim;
if (count == 0)
return {};
if (!dim)
return nullptr;
int64_t dimInt = dim.getValue().getSExtValue();
if (dimInt < 0)
dimInt += inType.getSizes().size();
bool unaryNonDim = true;
for (int i = 0, s = outType.getSizes().size(); i < s; ++i)
unaryNonDim &= outType.getSizes()[i] == 1 || i == dimInt;
// Fold the slice if the output tensor is relatively small, currently
// coded to 16:
if (input && start && step && dim && count < 16 && unaryNonDim &&
count < 16) {
int64_t inCount = input.getNumElements();
int64_t begin = start.getValue().getSExtValue();
int64_t stride = step.getValue().getSExtValue();
if (stride < 1)
return {};
int64_t limit = end.getValue().getSExtValue();
begin = begin < 0 ? begin + inCount : begin;
limit = limit < 0 ? limit + inCount : limit;
limit = limit < 0 ? inType.getSizes()[dimInt] : limit;
limit = std::min(limit, inType.getSizes()[dimInt]);
llvm::SmallVector<Attribute> values;
for (int i = begin; i < limit; i += stride)
values.push_back(input.getValues<Attribute>()[i]);
return DenseElementsAttr::get(
outType.toBuiltinTensor().clone(inType.getDtype()), values);
}
// If the input and output shapes are the same we can just fold:

View File

@ -1,4 +1,4 @@
// RUN: torch-mlir-opt %s -canonicalize | FileCheck %s
// RUN: torch-mlir-opt %s -canonicalize --split-input-file | FileCheck %s
// CHECK-LABEL: func.func @torch.aten.__range_length$fold() -> (!torch.int, !torch.int, !torch.int, !torch.int) {
// CHECK-DAG: %[[INT1:.*]] = torch.constant.int 1
@ -1990,6 +1990,7 @@ func.func @torch.aten.sort$nofold (%arg0 : !torch.vtensor<[3, 1, 4],si64>, %arg1
return %0, %1 : !torch.vtensor<[3, 1, 4],si64>, !torch.vtensor<[3],si64>
}
// -----
// CHECK-LABEL: @torch.aten.cat$fold_single_operand
// CHECK-SAME: %[[ARG0:.+]]: !torch.tensor
@ -2001,6 +2002,22 @@ func.func @torch.aten.cat$fold_single_operand(%arg0: !torch.tensor) -> !torch.te
return %1: !torch.tensor
}
// -----
// CHECK-LABEL: @torch.aten.cat$fold_zero_dim_operand
// CHECK: %[[FOLD:.+]] = torch.vtensor.literal(dense<[1, 3, 2, 2]> : tensor<4xsi32>)
// CHECK: return %[[FOLD]] : !torch.vtensor
func.func @torch.aten.cat$fold_zero_dim_operand() -> !torch.vtensor<[4],si32> {
%0 = torch.vtensor.literal(dense<[1, 3]> : tensor<2xsi32>) : !torch.vtensor<[2],si32>
%1 = torch.vtensor.literal(dense<2> : tensor<2xsi32>) : !torch.vtensor<[2],si32>
%int0 = torch.constant.int 0
%list = torch.prim.ListConstruct %0, %1 : (!torch.vtensor<[2],si32>, !torch.vtensor<[2],si32>) -> !torch.list<vtensor>
%cat = torch.aten.cat %list, %int0 : !torch.list<vtensor>, !torch.int -> !torch.vtensor<[4],si32>
return %cat: !torch.vtensor<[4],si32>
}
// -----
// CHECK-LABEL: func.func @torch.aten.broadcast_to$fold(
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[3,4,2],f32>) -> !torch.vtensor<[3,4,2],f32> {
// CHECK-NEXT: return %[[ARG]] : !torch.vtensor<[3,4,2],f32>
@ -2013,15 +2030,23 @@ func.func @torch.aten.broadcast_to$fold(%arg0: !torch.vtensor<[3,4,2],f32>) -> !
return %0 : !torch.vtensor<[3,4,2],f32>
}
// CHECK-LABEL: func.func @torch.aten.broadcast_to_strict$fold(
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[?],f32>, {{.*}}) -> !torch.vtensor<[?],f32>
// CHECK-NEXT: return %[[ARG]] : !torch.vtensor<[?],f32>
func.func @torch.aten.broadcast_to_strict$fold(%arg0: !torch.vtensor<[?],f32>, %arg1: !torch.int) -> !torch.vtensor<[?],f32> attributes {torch.assume_strict_symbolic_shapes} {
%list = torch.prim.ListConstruct %arg1 : (!torch.int) -> !torch.list<int>
%0 = torch.aten.broadcast_to %arg0, %list : !torch.vtensor<[?],f32>, !torch.list<int> -> !torch.vtensor<[?],f32>
return %0 : !torch.vtensor<[?],f32>
// -----
// CHECK-LABEL: func.func @torch.aten.broadcast_to$fold_splat
// CHECK: %[[CST:.+]] = torch.vtensor.literal(dense<3.000000e+00> : tensor<3x4x2xf32>) : !torch.vtensor<[3,4,2],f32>
// CHECK: return %[[CST]]
func.func @torch.aten.broadcast_to$fold_splat() -> !torch.vtensor<[3,4,2],f32> {
%tensor = torch.vtensor.literal(dense<3.0> : tensor<1x4x1xf32>) : !torch.vtensor<[1,4,1],f32>
%int3 = torch.constant.int 3
%int4 = torch.constant.int 4
%int2 = torch.constant.int 2
%list = torch.prim.ListConstruct %int3, %int4, %int2 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
%0 = torch.aten.broadcast_to %tensor, %list : !torch.vtensor<[1,4,1],f32>, !torch.list<int> -> !torch.vtensor<[3,4,2],f32>
return %0 : !torch.vtensor<[3,4,2],f32>
}
// -----
// CHECK-LABEL: @torch.aten.slice.tensor$fold_full_domain_slice
// CHECK-SAME: %[[ARG0:.+]]: !torch.vtensor<[4],f32>
// CHECK: return %[[ARG0]] : !torch.vtensor<[4],f32>
@ -2078,11 +2103,21 @@ func.func @torch.aten.slice.tensor$fold_dim_1() -> (!torch.vtensor<[1, 1],si64>,
// -----
// CHECK-LABEL: func.func @torch.aten.slice.tensor$fold_dim_0() -> (!torch.vtensor<[1,1],f32>, !torch.vtensor<[1,1],f32>) {
// CHECK-NOT: torch.aten.slice.Tensor
// CHECK: %[[RET_0:.*]] = torch.vtensor.literal(dense<1.600000e+01> : tensor<1x1xf32>) : !torch.vtensor<[1,1],f32>
// CHECK: %[[RET_1:.*]] = torch.vtensor.literal(dense<6.400000e+01> : tensor<1x1xf32>) : !torch.vtensor<[1,1],f32>
// CHECK: return %[[RET_0]], %[[RET_1]] : !torch.vtensor<[1,1],f32>, !torch.vtensor<[1,1],f32>
// CHECK-LABEL: func.func @torch.aten.slice.tensor$fold_small() -> !torch.vtensor<[2],si32> {
// CHECK: %[[CST:.+]] = torch.vtensor.literal(dense<[3, 5]> : tensor<2xsi32>) : !torch.vtensor<[2],si32>
// CHECK: return %[[CST]]
func.func @torch.aten.slice.tensor$fold_small() -> (!torch.vtensor<[2],si32>) {
%tensor = torch.vtensor.literal(dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]> : tensor<10xsi32>) : !torch.vtensor<[10],si32>
%dim = torch.constant.int 0
%int3 = torch.constant.int 3
%int2 = torch.constant.int 2
%int7 = torch.constant.int 7
%0 = torch.aten.slice.Tensor %tensor, %dim, %int3, %int7, %int2 : !torch.vtensor<[10], si32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[2], si32>
return %0 : !torch.vtensor<[2],si32>
}
// -----
func.func @torch.aten.slice.tensor$fold_dim_0() -> (!torch.vtensor<[1, 1],f32>, !torch.vtensor<[1, 1],f32>) {
%tensor = torch.vtensor.literal(dense<[[2.0],[4.0],[8.0],[16.0],[32.0],[64.0],[128.0],[256.0],[512.0],[1024.0]]> : tensor<10x1xf32>) : !torch.vtensor<[10, 1],f32>
%int0 = torch.constant.int 0
@ -2097,7 +2132,7 @@ func.func @torch.aten.slice.tensor$fold_dim_0() -> (!torch.vtensor<[1, 1],f32>,
return %0, %1 : !torch.vtensor<[1, 1],f32>, !torch.vtensor<[1, 1], f32>
}
// -----
// CHECK-LABEL: func.func @torch.aten.rsub.Scalar$canonicalize_literal_0d() -> !torch.vtensor<[],si64> {
// CHECK: %[[CST:.+]] = torch.vtensor.literal(dense<-1> : tensor<si64>) : !torch.vtensor<[],si64>