mirror of https://github.com/llvm/torch-mlir
[torch] Additional folders for shape computations (#2972)
A handful of operations are commonly used in shape calculations (slice, concat, broadcast). Added these additional folders to better propagate simple shape computations.
parent 09875fabd1
commit a86e89ecb5
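As an illustration of the concat folder, here is a minimal MLIR sketch (the function name is hypothetical; the IR mirrors the cat$fold_zero_dim_operand test added below). Under -canonicalize, the concatenation of two constant shape tensors along dim 0 now folds to a single literal:

  func.func @shape_concat_example() -> !torch.vtensor<[4],si32> {
    %0 = torch.vtensor.literal(dense<[1, 3]> : tensor<2xsi32>) : !torch.vtensor<[2],si32>
    %1 = torch.vtensor.literal(dense<2> : tensor<2xsi32>) : !torch.vtensor<[2],si32>
    %int0 = torch.constant.int 0
    %list = torch.prim.ListConstruct %0, %1 : (!torch.vtensor<[2],si32>, !torch.vtensor<[2],si32>) -> !torch.list<vtensor>
    // Folds to torch.vtensor.literal(dense<[1, 3, 2, 2]> : tensor<4xsi32>).
    %cat = torch.aten.cat %list, %int0 : !torch.list<vtensor>, !torch.int -> !torch.vtensor<[4],si32>
    return %cat : !torch.vtensor<[4],si32>
  }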
@@ -2899,12 +2899,59 @@ OpFoldResult AtenSubIntOp::fold(FoldAdaptor adaptor) {
 //===----------------------------------------------------------------------===//
 
 OpFoldResult AtenCatOp::fold(FoldAdaptor adaptor) {
+  // We set a maximum folding size of 16. This is a reasonable upper limit
+  // for shape computations.
+  constexpr int64_t kMaxFoldSize = 16;
   auto list = getOperand(0).getDefiningOp<PrimListConstructOp>();
-  if (!list || list.getElements().size() != 1)
-    return nullptr;
-  if (list.getElements()[0].getType() != getResult().getType())
+  if (!list)
     return nullptr;
-  return list.getElements()[0];
+
+  auto elements = list.getElements();
+  if (elements.size() == 1 && elements[0].getType() == getResult().getType())
+    return list.getElements()[0];
+
+  auto resultTy = dyn_cast<ValueTensorType>(getType());
+  if (!resultTy || !resultTy.hasSizes() || !resultTy.hasDtype())
+    return nullptr;
+
+  auto bResultTy = resultTy.toBuiltinTensor();
+  if (!bResultTy.hasStaticShape() || bResultTy.getNumElements() > kMaxFoldSize)
+    return nullptr;
+
+  auto dimAttr = dyn_cast_or_null<IntegerAttr>(adaptor.getDim());
+  if (!dimAttr)
+    return nullptr;
+  auto dim = dimAttr.getValue().getSExtValue();
+  dim += dim < 0 ? bResultTy.getRank() : 0;
+
+  for (int i = 0, s = bResultTy.getRank(); i < s; ++i) {
+    if (i == dim)
+      continue;
+    if (bResultTy.getDimSize(i) != 1)
+      return nullptr;
+  }
+
+  llvm::SmallVector<Attribute> values;
+  for (auto operand : list.getOperands()) {
+    DenseElementsAttr dattr;
+    if (!matchPattern(operand, m_Constant(&dattr)))
+      return nullptr;
+
+    auto oty = dyn_cast<RankedTensorType>(dattr.getType());
+    if (!oty)
+      return nullptr;
+
+    if (dattr.isSplat()) {
+      for (int i = 0, s = oty.getDimSize(dim); i < s; ++i)
+        values.push_back(dattr.getSplatValue<Attribute>());
+    } else {
+      auto evals = dattr.getValues<Attribute>();
+      for (int i = 0, s = oty.getDimSize(dim); i < s; ++i)
+        values.push_back(evals[i]);
+    }
+  }
+
+  return DenseElementsAttr::get(bResultTy.clone(resultTy.getDtype()), values);
 }
 
 void AtenCatOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
@@ -2947,21 +2994,34 @@ void AtenCatOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
 OpFoldResult AtenBroadcastToOp::fold(FoldAdaptor adaptor) {
   auto inType = getOperand(0).getType().dyn_cast<BaseTensorType>();
   auto outType = getResult().getType().dyn_cast<BaseTensorType>();
-  if (inType != outType)
+  if (!inType || !outType || !inType.hasSizes() || !outType.hasSizes() ||
+      !outType.hasDtype())
     return nullptr;
-  if (!inType || !outType || !inType.hasSizes() || !outType.hasSizes())
+
+  if (!inType.areAllSizesKnown() || !outType.areAllSizesKnown())
     return nullptr;
-  if (inType.getSizes().size() != outType.getSizes().size() ||
-      (!isAssumingStrictSymbolicShapes((*this)->getBlock()) &&
-       (!inType.areAllSizesKnown() || !outType.areAllSizesKnown())))
-    return nullptr;
-  for (size_t i = 0; i < inType.getSizes().size(); ++i) {
-    if (inType.getSizes()[i] != outType.getSizes()[i])
-      return nullptr;
-  }
-  return getOperand(0);
+
+  auto inSizes = inType.getSizes();
+  auto outSizes = outType.getSizes();
+  if (inSizes.size() == outSizes.size()) {
+    bool sameSizes = true;
+    for (int i = 0, s = inSizes.size(); i < s; ++i)
+      sameSizes &= inSizes[i] == outSizes[i];
+
+    if (sameSizes)
+      return getOperand(0);
+  }
+
+  auto selfAttr = dyn_cast_or_null<DenseElementsAttr>(adaptor.getSelf());
+  if (!selfAttr)
+    return nullptr;
+  if (!selfAttr.isSplat())
+    return nullptr;
+
+  auto attrty = RankedTensorType::get(outType.getSizes(), outType.getDtype());
+  return DenseElementsAttr::get(attrty, selfAttr.getSplatValue<Attribute>());
 }
 
 //===----------------------------------------------------------------------===//
 // AtenSliceTensorOp
 //===----------------------------------------------------------------------===//
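The broadcast folder above now also materializes a splat literal whenever the input is a constant splat and the output shape is fully static. A sketch mirroring the new broadcast_to$fold_splat test (the function name is hypothetical):

  func.func @broadcast_splat_example() -> !torch.vtensor<[3,4,2],f32> {
    %tensor = torch.vtensor.literal(dense<3.0> : tensor<1x4x1xf32>) : !torch.vtensor<[1,4,1],f32>
    %int3 = torch.constant.int 3
    %int4 = torch.constant.int 4
    %int2 = torch.constant.int 2
    %list = torch.prim.ListConstruct %int3, %int4, %int2 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
    // Folds to torch.vtensor.literal(dense<3.000000e+00> : tensor<3x4x2xf32>).
    %0 = torch.aten.broadcast_to %tensor, %list : !torch.vtensor<[1,4,1],f32>, !torch.list<int> -> !torch.vtensor<[3,4,2],f32>
    return %0 : !torch.vtensor<[3,4,2],f32>
  }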
@@ -2995,23 +3055,44 @@ OpFoldResult AtenSliceTensorOp::fold(FoldAdaptor adaptor) {
         outType.toBuiltinTensor().clone(inType.getDtype()),
         input.getSplatValue<Attribute>());
 
-  // If the output is a single value we can index into a constant input and grab
-  // that single value:
-  if (input && start && dim &&
-      llvm::all_of(outType.getSizes(), [](int64_t dim) { return dim == 1; })) {
-    bool unaryNonDim = true;
-    int64_t dimInt = dim.getValue().getSExtValue();
-    for (int i = 0, s = inType.getSizes().size(); i < s; ++i) {
-      unaryNonDim &= inType.getSizes()[i] == 1 || i == dimInt;
-    }
-    if (unaryNonDim) {
-      int64_t idx = start.getValue().getSExtValue();
-      if (idx < 0)
-        idx += input.getNumElements();
-      Attribute value = input.getValues<Attribute>()[idx];
-      return DenseElementsAttr::get(
-          outType.toBuiltinTensor().clone(inType.getDtype()), value);
-    }
-  }
+  int count = 1;
+  for (auto dim : outType.getSizes())
+    count = count * dim;
+
+  if (count == 0)
+    return {};
+
+  if (!dim)
+    return nullptr;
+  int64_t dimInt = dim.getValue().getSExtValue();
+  if (dimInt < 0)
+    dimInt += inType.getSizes().size();
+
+  bool unaryNonDim = true;
+  for (int i = 0, s = outType.getSizes().size(); i < s; ++i)
+    unaryNonDim &= outType.getSizes()[i] == 1 || i == dimInt;
+
+  // Fold the slice if the output tensor is relatively small, currently
+  // coded to 16:
+  if (input && start && step && dim && count < 16 && unaryNonDim &&
+      count < 16) {
+    int64_t inCount = input.getNumElements();
+    int64_t begin = start.getValue().getSExtValue();
+    int64_t stride = step.getValue().getSExtValue();
+    if (stride < 1)
+      return {};
+    int64_t limit = end.getValue().getSExtValue();
+    begin = begin < 0 ? begin + inCount : begin;
+    limit = limit < 0 ? limit + inCount : limit;
+    limit = limit < 0 ? inType.getSizes()[dimInt] : limit;
+    limit = std::min(limit, inType.getSizes()[dimInt]);
+
+    llvm::SmallVector<Attribute> values;
+    for (int i = begin; i < limit; i += stride)
+      values.push_back(input.getValues<Attribute>()[i]);
+
+    return DenseElementsAttr::get(
+        outType.toBuiltinTensor().clone(inType.getDtype()), values);
+  }
 
   // If the input and output shapes are the same we can just fold:
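Similarly, the slice folder evaluates strided slices of small constants (outputs with fewer than 16 elements) directly. A sketch mirroring the new slice.tensor$fold_small test in the canonicalization suite below (the function name is hypothetical):

  func.func @shape_slice_example() -> !torch.vtensor<[2],si32> {
    %tensor = torch.vtensor.literal(dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]> : tensor<10xsi32>) : !torch.vtensor<[10],si32>
    %int0 = torch.constant.int 0
    %int3 = torch.constant.int 3
    %int7 = torch.constant.int 7
    %int2 = torch.constant.int 2
    // slice(dim=0, start=3, end=7, step=2) folds to dense<[3, 5]>.
    %0 = torch.aten.slice.Tensor %tensor, %int0, %int3, %int7, %int2 : !torch.vtensor<[10],si32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[2],si32>
    return %0 : !torch.vtensor<[2],si32>
  }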
@@ -1,4 +1,4 @@
-// RUN: torch-mlir-opt %s -canonicalize | FileCheck %s
+// RUN: torch-mlir-opt %s -canonicalize --split-input-file | FileCheck %s
 
 // CHECK-LABEL: func.func @torch.aten.__range_length$fold() -> (!torch.int, !torch.int, !torch.int, !torch.int) {
 // CHECK-DAG: %[[INT1:.*]] = torch.constant.int 1
@@ -1990,6 +1990,7 @@ func.func @torch.aten.sort$nofold (%arg0 : !torch.vtensor<[3, 1, 4],si64>, %arg1
   return %0, %1 : !torch.vtensor<[3, 1, 4],si64>, !torch.vtensor<[3],si64>
 }
 
+// -----
 
 // CHECK-LABEL: @torch.aten.cat$fold_single_operand
 // CHECK-SAME: %[[ARG0:.+]]: !torch.tensor
@@ -2001,6 +2002,22 @@ func.func @torch.aten.cat$fold_single_operand(%arg0: !torch.tensor) -> !torch.te
   return %1: !torch.tensor
 }
 
+// -----
+
+// CHECK-LABEL: @torch.aten.cat$fold_zero_dim_operand
+// CHECK: %[[FOLD:.+]] = torch.vtensor.literal(dense<[1, 3, 2, 2]> : tensor<4xsi32>)
+// CHECK: return %[[FOLD]] : !torch.vtensor
+func.func @torch.aten.cat$fold_zero_dim_operand() -> !torch.vtensor<[4],si32> {
+  %0 = torch.vtensor.literal(dense<[1, 3]> : tensor<2xsi32>) : !torch.vtensor<[2],si32>
+  %1 = torch.vtensor.literal(dense<2> : tensor<2xsi32>) : !torch.vtensor<[2],si32>
+  %int0 = torch.constant.int 0
+  %list = torch.prim.ListConstruct %0, %1 : (!torch.vtensor<[2],si32>, !torch.vtensor<[2],si32>) -> !torch.list<vtensor>
+  %cat = torch.aten.cat %list, %int0 : !torch.list<vtensor>, !torch.int -> !torch.vtensor<[4],si32>
+  return %cat: !torch.vtensor<[4],si32>
+}
+
+// -----
+
 // CHECK-LABEL: func.func @torch.aten.broadcast_to$fold(
 // CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[3,4,2],f32>) -> !torch.vtensor<[3,4,2],f32> {
 // CHECK-NEXT: return %[[ARG]] : !torch.vtensor<[3,4,2],f32>
@@ -2013,15 +2030,23 @@ func.func @torch.aten.broadcast_to$fold(%arg0: !torch.vtensor<[3,4,2],f32>) -> !
   return %0 : !torch.vtensor<[3,4,2],f32>
 }
 
-// CHECK-LABEL: func.func @torch.aten.broadcast_to_strict$fold(
-// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[?],f32>, {{.*}}) -> !torch.vtensor<[?],f32>
-// CHECK-NEXT: return %[[ARG]] : !torch.vtensor<[?],f32>
-func.func @torch.aten.broadcast_to_strict$fold(%arg0: !torch.vtensor<[?],f32>, %arg1: !torch.int) -> !torch.vtensor<[?],f32> attributes {torch.assume_strict_symbolic_shapes} {
-  %list = torch.prim.ListConstruct %arg1 : (!torch.int) -> !torch.list<int>
-  %0 = torch.aten.broadcast_to %arg0, %list : !torch.vtensor<[?],f32>, !torch.list<int> -> !torch.vtensor<[?],f32>
-  return %0 : !torch.vtensor<[?],f32>
+// -----
+
+// CHECK-LABEL: func.func @torch.aten.broadcast_to$fold_splat
+// CHECK: %[[CST:.+]] = torch.vtensor.literal(dense<3.000000e+00> : tensor<3x4x2xf32>) : !torch.vtensor<[3,4,2],f32>
+// CHECK: return %[[CST]]
+func.func @torch.aten.broadcast_to$fold_splat() -> !torch.vtensor<[3,4,2],f32> {
+  %tensor = torch.vtensor.literal(dense<3.0> : tensor<1x4x1xf32>) : !torch.vtensor<[1,4,1],f32>
+  %int3 = torch.constant.int 3
+  %int4 = torch.constant.int 4
+  %int2 = torch.constant.int 2
+  %list = torch.prim.ListConstruct %int3, %int4, %int2 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
+  %0 = torch.aten.broadcast_to %tensor, %list : !torch.vtensor<[1,4,1],f32>, !torch.list<int> -> !torch.vtensor<[3,4,2],f32>
+  return %0 : !torch.vtensor<[3,4,2],f32>
 }
 
+// -----
+
 // CHECK-LABEL: @torch.aten.slice.tensor$fold_full_domain_slice
 // CHECK-SAME: %[[ARG0:.+]]: !torch.vtensor<[4],f32>
 // CHECK: return %[[ARG0]] : !torch.vtensor<[4],f32>
@@ -2078,11 +2103,21 @@ func.func @torch.aten.slice.tensor$fold_dim_1() -> (!torch.vtensor<[1, 1],si64>,
 
 
 // -----
-// CHECK-LABEL: func.func @torch.aten.slice.tensor$fold_dim_0() -> (!torch.vtensor<[1,1],f32>, !torch.vtensor<[1,1],f32>) {
-// CHECK-NOT: torch.aten.slice.Tensor
-// CHECK: %[[RET_0:.*]] = torch.vtensor.literal(dense<1.600000e+01> : tensor<1x1xf32>) : !torch.vtensor<[1,1],f32>
-// CHECK: %[[RET_1:.*]] = torch.vtensor.literal(dense<6.400000e+01> : tensor<1x1xf32>) : !torch.vtensor<[1,1],f32>
-// CHECK: return %[[RET_0]], %[[RET_1]] : !torch.vtensor<[1,1],f32>, !torch.vtensor<[1,1],f32>
+// CHECK-LABEL: func.func @torch.aten.slice.tensor$fold_small() -> !torch.vtensor<[2],si32> {
+// CHECK: %[[CST:.+]] = torch.vtensor.literal(dense<[3, 5]> : tensor<2xsi32>) : !torch.vtensor<[2],si32>
+// CHECK: return %[[CST]]
+func.func @torch.aten.slice.tensor$fold_small() -> (!torch.vtensor<[2],si32>) {
+  %tensor = torch.vtensor.literal(dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]> : tensor<10xsi32>) : !torch.vtensor<[10],si32>
+  %dim = torch.constant.int 0
+  %int3 = torch.constant.int 3
+  %int2 = torch.constant.int 2
+  %int7 = torch.constant.int 7
+  %0 = torch.aten.slice.Tensor %tensor, %dim, %int3, %int7, %int2 : !torch.vtensor<[10], si32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[2], si32>
+  return %0 : !torch.vtensor<[2],si32>
+}
+
+// -----
+
 func.func @torch.aten.slice.tensor$fold_dim_0() -> (!torch.vtensor<[1, 1],f32>, !torch.vtensor<[1, 1],f32>) {
   %tensor = torch.vtensor.literal(dense<[[2.0],[4.0],[8.0],[16.0],[32.0],[64.0],[128.0],[256.0],[512.0],[1024.0]]> : tensor<10x1xf32>) : !torch.vtensor<[10, 1],f32>
   %int0 = torch.constant.int 0
@@ -2097,7 +2132,7 @@ func.func @torch.aten.slice.tensor$fold_dim_0() -> (!torch.vtensor<[1, 1],f32>,
   return %0, %1 : !torch.vtensor<[1, 1],f32>, !torch.vtensor<[1, 1], f32>
 }
 
-
+// -----
 
 // CHECK-LABEL: func.func @torch.aten.rsub.Scalar$canonicalize_literal_0d() -> !torch.vtensor<[],si64> {
 // CHECK: %[[CST:.+]] = torch.vtensor.literal(dense<-1> : tensor<si64>) : !torch.vtensor<[],si64>