[torch] AtenSliceOp folder that produces splat results (#2869)

Includes `slice` folder and lit tests

---------

Co-authored-by: Xida Ren <xida.ren.dev@gmail.com>
revert-2869-AtenSlice-folder
Xida Ren (Cedar) 2024-02-07 14:00:46 -05:00 committed by GitHub
parent 723b8b1d28
commit fc04bc7ee9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 87 additions and 10 deletions

View File

@ -2552,22 +2552,52 @@ OpFoldResult AtenBroadcastToOp::fold(FoldAdaptor adaptor) {
//===----------------------------------------------------------------------===//
OpFoldResult AtenSliceTensorOp::fold(FoldAdaptor adaptor) {
int64_t start, end, step;
if (matchPattern(getStart(), m_TorchConstantInt(&start)) &&
matchPattern(getEnd(), m_TorchConstantInt(&end)) &&
matchPattern(getStep(), m_TorchConstantInt(&step)) && step == 1 &&
start == 0 && end == std::numeric_limits<int64_t>::max())
DenseElementsAttr input =
dyn_cast_or_null<DenseElementsAttr>(adaptor.getSelf());
IntegerAttr start = dyn_cast_or_null<IntegerAttr>(adaptor.getStart());
IntegerAttr end = dyn_cast_or_null<IntegerAttr>(adaptor.getEnd());
IntegerAttr step = dyn_cast_or_null<IntegerAttr>(adaptor.getStep());
IntegerAttr dim = dyn_cast_or_null<IntegerAttr>(adaptor.getDim());
if (start && end && step && step.getValue().getSExtValue() == 1 &&
start.getValue().getSExtValue() == 0 &&
end.getValue().getSExtValue() == std::numeric_limits<int64_t>::max())
return getOperand(0);
auto inType = getOperand(0).getType().dyn_cast<BaseTensorType>();
auto outType = getResult().getType().dyn_cast<BaseTensorType>();
if (inType != outType)
return nullptr;
if (!inType || !outType || !inType.hasSizes() || !outType.hasSizes())
auto inType = getOperand(0).getType().dyn_cast<ValueTensorType>();
auto outType = getResult().getType().dyn_cast<ValueTensorType>();
if (!inType || !outType || !inType.hasSizes() || !outType.hasSizes() ||
!inType.hasDtype() || !outType.hasDtype() ||
inType.getDtype() != outType.getDtype())
return nullptr;
if (inType.getSizes().size() != outType.getSizes().size() ||
!inType.areAllSizesKnown() || !outType.areAllSizesKnown())
return nullptr;
if (input && input.isSplat())
return DenseElementsAttr::get(
outType.toBuiltinTensor().clone(inType.getDtype()),
input.getSplatValue<Attribute>());
// If the output is a single value we can index into a constant input and grab
// that single value:
if (input && start && dim &&
llvm::all_of(outType.getSizes(), [](int64_t dim) { return dim == 1; })) {
bool unaryNonDim = true;
int64_t dimInt = dim.getValue().getSExtValue();
for (int i = 0, s = inType.getSizes().size(); i < s; ++i) {
unaryNonDim &= inType.getSizes()[i] == 1 || i == dimInt;
}
if (unaryNonDim) {
Attribute value =
input.getValues<Attribute>()[start.getValue().getSExtValue()];
return DenseElementsAttr::get(
outType.toBuiltinTensor().clone(inType.getDtype()), value);
}
}
// If the input and output shapes are the same we can just fold:
for (size_t i = 0; i < inType.getSizes().size(); ++i) {
if (inType.getSizes()[i] != outType.getSizes()[i])
return nullptr;

View File

@ -2104,6 +2104,53 @@ func.func @torch.aten.slice.tensor$no_fold_step(%arg0: !torch.vtensor<[?],f32>,
return %0 : !torch.vtensor<[?],f32>
}
// -----
// CHECK-LABEL: func.func @torch.aten.slice.tensor$fold_dim_1() -> (!torch.vtensor<[1,1],si64>, !torch.vtensor<[1,1],si64>) {
// CHECK-NOT: torch.aten.slice.Tensor
// CHECK: %[[RET_0:.*]] = torch.vtensor.literal(dense<50> : tensor<1x1xsi64>) : !torch.vtensor<[1,1],si64>
// CHECK-NOT: torch.aten.slice.Tensor
// CHECK: %[[RET_1:.*]] = torch.vtensor.literal(dense<70> : tensor<1x1xsi64>) : !torch.vtensor<[1,1],si64>
// CHECK-NOT: torch.aten.slice.Tensor
// CHECK: return %[[RET_0]], %[[RET_1]]
func.func @torch.aten.slice.tensor$fold_dim_1() -> (!torch.vtensor<[1, 1],si64>, !torch.vtensor<[1, 1],si64>) {
%tensor = torch.vtensor.literal(dense<[[10,20,30,40,50,60,70,80,90,100]]> : tensor<1x10xsi64>) : !torch.vtensor<[1, 10],si64>
%int0 = torch.constant.int 0
%int1 = torch.constant.int 1
%int4 = torch.constant.int 4
%int5 = torch.constant.int 5
%int6 = torch.constant.int 6
%int7 = torch.constant.int 7
%dim = torch.constant.int 1
%0 = torch.aten.slice.Tensor %tensor, %dim, %int4, %int5, %int1 : !torch.vtensor<[1, 10], si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1, 1], si64>
%1 = torch.aten.slice.Tensor %tensor, %dim, %int6, %int7, %int1 : !torch.vtensor<[1, 10], si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1, 1], si64>
return %0, %1 : !torch.vtensor<[1,1],si64>, !torch.vtensor<[1,1],si64>
}
// -----
// CHECK-LABEL: func.func @torch.aten.slice.tensor$fold_dim_0() -> (!torch.vtensor<[1,1],f32>, !torch.vtensor<[1,1],f32>) {
// CHECK-NOT: torch.aten.slice.Tensor
// CHECK: %[[RET_0:.*]] = torch.vtensor.literal(dense<1.600000e+01> : tensor<1x1xf32>) : !torch.vtensor<[1,1],f32>
// CHECK-NOT: torch.aten.slice.Tensor
// CHECK: %[[RET_1:.*]] = torch.vtensor.literal(dense<6.400000e+01> : tensor<1x1xf32>) : !torch.vtensor<[1,1],f32>
// CHECK-NOT: torch.aten.slice.Tensor
// CHECK: return %[[RET_0]], %[[RET_1]] : !torch.vtensor<[1,1],f32>, !torch.vtensor<[1,1],f32>
func.func @torch.aten.slice.tensor$fold_dim_0() -> (!torch.vtensor<[1, 1],f32>, !torch.vtensor<[1, 1],f32>) {
%tensor = torch.vtensor.literal(dense<[[2.0],[4.0],[8.0],[16.0],[32.0],[64.0],[128.0],[256.0],[512.0],[1024.0]]> : tensor<10x1xf32>) : !torch.vtensor<[10, 1],f32>
%int0 = torch.constant.int 0
%int1 = torch.constant.int 1
%int3 = torch.constant.int 3
%int4 = torch.constant.int 4
%int5 = torch.constant.int 5
%int6 = torch.constant.int 6
%dim = torch.constant.int 0
%0 = torch.aten.slice.Tensor %tensor, %dim, %int3, %int4, %int1 : !torch.vtensor<[10, 1], f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1, 1], f32>
%1 = torch.aten.slice.Tensor %tensor, %dim, %int5, %int6, %int1 : !torch.vtensor<[10, 1], f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1, 1], f32>
return %0, %1 : !torch.vtensor<[1, 1],f32>, !torch.vtensor<[1, 1], f32>
}
// CHECK-LABEL: func.func @torch.aten.rsub.Scalar$canonicalize_literal_0d() -> !torch.vtensor<[],si64> {
// CHECK: %int-1 = torch.constant.int -1
// CHECK: %[[VAL_0:.*]] = torch.prim.NumToTensor.Scalar %int-1 : !torch.int -> !torch.vtensor<[],si64>