[torch] aten.index_select folder (#2871)

Folds aten::index_select ops under the following conditions:

1. If the input and output have the same shape, the indexing operation is
a no-op, so just return the input.
2. If the input has shape <1x1x...xNx...x1> (all 1s except for one
dim) and the output shape is <1x1x...x1> (all 1s), then there is a
single index, so extract the single element value and return a tensor
holding that value (see the sketch after this list).
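
As an illustration, here is a minimal standalone C++ sketch of the two shape
tests; `foldKind` and `kDynamic` are hypothetical names used only for this
sketch, and `dim` is assumed to be already non-negative:

#include <cstdint>
#include <vector>

// Classifies a potential fold: 1 = NOP case, 2 = single-element case,
// 0 = no fold applies. kDynamic marks an unknown extent.
constexpr int64_t kDynamic = -1;

int foldKind(const std::vector<int64_t> &inShape,
             const std::vector<int64_t> &outShape, int64_t dim) {
  if (inShape.size() != outShape.size())
    return 0;
  // Condition 1: identical, fully static shapes -> the selection is a no-op.
  // (The real folder additionally checks that the index list holds exactly
  // one element.)
  bool same = true;
  for (size_t i = 0; i < inShape.size(); ++i)
    same = same && inShape[i] == outShape[i] && inShape[i] != kDynamic;
  if (same)
    return 1;
  // Condition 2: input is all 1s except `dim`, output is all 1s, so the
  // selection picks out exactly one element.
  for (size_t i = 0; i < inShape.size(); ++i)
    if ((inShape[i] != 1 && int64_t(i) != dim) || outShape[i] != 1)
      return 0;
  return 2;
}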

---------

Co-authored-by: Dave Liddell <dliddell@xilinx.com>
pull/2891/head
Dave Liddell 2024-02-07 17:17:15 -07:00 committed by GitHub
parent 32dbf99ce2
commit 23647ab2d1
5 changed files with 177 additions and 1 deletion


@@ -9785,6 +9785,7 @@ def Torch_AtenIndexSelectOp : Torch_Op<"aten.index_select", [
printDefaultTorchOp(printer, *this, 3, 1);
}
}];
let hasFolder = 1;
}
def Torch_Aten_IndexPutImplOp : Torch_Op<"aten._index_put_impl", [


@@ -294,6 +294,44 @@ bool isListPotentiallyMutated(Value list);
/// the list.
bool potentiallyMutatesListOperands(Operation *op);
/// Returns the value from an `IntegerAttr` as an `int64_t`.
///
/// @param intAttr the `IntegerAttr` from which to extract the value
/// @return the value as an `int64_t`
///
/// Regardless of the signedness of the attribute, this function returns
/// the value as a signed integer; if the attribute holds a 64-bit unsigned
/// value, it is converted the same way a uint64_t is cast to int64_t in C++.
inline int64_t getIntAttrAsSigned(IntegerAttr intAttr) {
if (intAttr.getType().isUnsignedInteger())
return intAttr.getValue().getZExtValue();
return intAttr.getValue().getSExtValue();
}
/// Returns the value from an `IntegerAttr` as an integral index.
///
/// @param intAttr the `IntegerAttr` from which to extract the index
/// @param dimSize the size of the dimension that the attribute indexes into
/// @return the index value
///
/// Use this function when the given `IntegerAttr` represents an index into
/// a range, such as an index into a tensor dimension. If `dimSize` is given,
/// negative index values are converted into positive values by counting
/// elements from the "right" side of the dimension, as in Python, NumPy, etc.
/// For example, an index of -2 and a dimSize of 10 returns 8 because 8 is the
/// 2nd index from the high end of the range 0 to 9. If `dimSize` is not
/// given, any negative indices are returned as negative numbers.
///
/// No bounds checking is performed on the index to ensure that it is within
/// the legal range for `dimSize`.
inline int64_t getIntAttrAsIndex(IntegerAttr intAttr, int dimSize = -1) {
int64_t signedIndex = getIntAttrAsSigned(intAttr);
if (dimSize < 0 || signedIndex >= 0)
return signedIndex;
return dimSize + signedIndex; // count backwards from dimSize
}
} // namespace Torch
} // namespace torch
} // namespace mlir
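
For reference, a hedged usage sketch of the two helpers above; the `demo`
function, the context setup, and the exact include paths are illustrative
assumptions, not part of this patch:

#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"

void demo() {
  mlir::MLIRContext ctx;

  // An unsigned 64-bit attribute holding UINT64_MAX reads back as -1,
  // matching the uint64_t -> int64_t cast behavior described above.
  auto ui64 = mlir::IntegerType::get(&ctx, 64, mlir::IntegerType::Unsigned);
  auto big = mlir::IntegerAttr::get(ui64, llvm::APInt(64, ~0ULL));
  int64_t asSigned = mlir::torch::Torch::getIntAttrAsSigned(big); // == -1

  // A signed index of -2 into a dimension of size 10 resolves to 8,
  // counting from the high end of the range 0..9.
  auto si64 = mlir::IntegerType::get(&ctx, 64, mlir::IntegerType::Signed);
  auto negTwo =
      mlir::IntegerAttr::get(si64, llvm::APInt(64, -2, /*isSigned=*/true));
  int64_t idx =
      mlir::torch::Torch::getIntAttrAsIndex(negTwo, /*dimSize=*/10); // == 8
  (void)asSigned;
  (void)idx;
}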


@@ -2911,6 +2911,91 @@ OpFoldResult AtenDivIntOp::fold(FoldAdaptor adaptor) {
return nullptr;
}
//===----------------------------------------------------------------------===//
// AtenIndexSelectOp
//===----------------------------------------------------------------------===//
OpFoldResult AtenIndexSelectOp::fold(FoldAdaptor adaptor) {
auto self = getSelf();
auto index = getIndex();
auto selfTy = dyn_cast<ValueTensorType>(self.getType());
auto indexTy = dyn_cast<ValueTensorType>(index.getType());
auto resultTy = dyn_cast<ValueTensorType>(getType());
if (!selfTy || !indexTy || !resultTy || !selfTy.hasSizes() ||
!indexTy.hasSizes() || !resultTy.hasSizes() || !selfTy.hasDtype() ||
!indexTy.hasDtype() || !resultTy.hasDtype())
return nullptr;
auto selfSizes = selfTy.getSizes();
auto indexSizes = indexTy.getSizes();
auto resultSizes = resultTy.getSizes();
if (selfTy.getDtype() != resultTy.getDtype() ||
selfSizes.size() != resultSizes.size() || indexSizes.size() != 1)
return nullptr;
// If the selection results in a tensor of the same dimensions as the
// input, the selection must have specified every index of the input,
// so the result is exactly the same as the input.
bool fullTensor = true;
for (int i = 0, s = selfSizes.size(); i < s; ++i) {
fullTensor &= selfSizes[i] == resultSizes[i];
fullTensor &= selfSizes[i] != Torch::kUnknownSize;
fullTensor &= resultSizes[i] != Torch::kUnknownSize;
}
if (fullTensor && indexSizes[0] == 1)
return self;
// If the input tensor, the dim, or the indices are not constant, we
// can't fold.
auto selfAttr = dyn_cast_or_null<DenseElementsAttr>(adaptor.getSelf());
auto dimAttr = dyn_cast_or_null<IntegerAttr>(adaptor.getDim());
auto indexAttr = dyn_cast_or_null<DenseElementsAttr>(adaptor.getIndex());
if (!selfAttr || !dimAttr || !indexAttr)
return nullptr;
// If the input's dimensions are all 1 except for one dimension, and if
// there is a single index in the index list (as detected by the result
// dimension being 1), then fold to a <1x1x...x1> tensor literal containing
// a single element. Handles float and int types.
int64_t dimInt = dimAttr.getInt();
// If the selected dim is negative, count backwards from the last dim
if (dimInt < 0)
dimInt = selfSizes.size() + dimInt;
assert(uint64_t(dimInt) < selfSizes.size() &&
"Selected dim >= number of dims");
for (int i = 0, s = selfSizes.size(); i < s; ++i) {
if ((selfSizes[i] != 1 && i != dimInt) || resultSizes[i] != 1)
return nullptr;
}
// Get the single index value for the selected dimension
auto splatValue = indexAttr.getSplatValue<IntegerAttr>();
int64_t indexInt = getIntAttrAsIndex(splatValue, selfSizes[dimInt]);
// Extract the single constant value from the input tensor and turn the
// extracted value into a single-element tensor of the output shape and dtype
auto splattr = selfAttr.getValues<Attribute>()[indexInt];
auto dty = resultTy.getDtype();
auto attrTy = resultTy.toBuiltinTensor().clone(dty);
if (auto floatAttr = dyn_cast<FloatAttr>(splattr))
return DenseElementsAttr::get(
attrTy, FloatAttr::get(dty, floatAttr.getValueAsDouble()));
if (auto intAttr = dyn_cast<IntegerAttr>(splattr)) {
return DenseElementsAttr::get(attrTy,
IntegerAttr::get(dty, intAttr.getValue()));
}
return nullptr;
}
//===----------------------------------------------------------------------===//
// AtenItemOp
//===----------------------------------------------------------------------===//


@@ -616,7 +616,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
emit("aten::broadcast_to : (Tensor, int[]) -> (Tensor)", has_folder=True)
emit("aten::index.Tensor : (Tensor, Tensor?[]) -> (Tensor)")
emit("aten::index.Tensor_hacked_twin : (Tensor, Tensor[]) -> (Tensor)")
emit("aten::index_select : (Tensor, int, Tensor) -> (Tensor)")
emit("aten::index_select : (Tensor, int, Tensor) -> (Tensor)", has_folder=True)
emit_with_mutating_variants("aten::_index_put_impl : (Tensor, Tensor?[], Tensor, bool, bool) -> (Tensor)")
emit("aten::item : (Tensor) -> (Scalar)", has_folder=True)
emit("aten::masked_select : (Tensor, Tensor) -> (Tensor)")


@@ -2280,3 +2280,55 @@ func.func @torch.aten.detach$canonicalize(%arg0: !torch.tensor<[1],f32>) -> !torch.tensor {
%1 = torch.aten.detach %arg0 : !torch.tensor<[1],f32> -> !torch.tensor
return %1 : !torch.tensor
}
// CHECK-LABEL: func.func @torch.aten.index_select$noop(
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[1,2,3],si64>
// CHECK-NEXT: return %[[ARG]] : !torch.vtensor<[1,2,3],si64>
func.func @torch.aten.index_select$noop(%arg0 : !torch.vtensor<[1,2,3],si64>, %arg1 : !torch.int, %arg2 : !torch.vtensor<[1],si64>) -> !torch.vtensor<[1,2,3],si64> {
%0 = torch.aten.index_select %arg0, %arg1, %arg2 : !torch.vtensor<[1,2,3],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1,2,3],si64>
return %0 : !torch.vtensor<[1,2,3],si64>
}
// CHECK-LABEL: func.func @torch.aten.index_select$const_si_si(
// CHECK-NEXT: %[[RES:.*]] = torch.vtensor.literal(dense<60> : tensor<1xsi64>) : !torch.vtensor<[1],si64>
// CHECK-NEXT: return %[[RES]] : !torch.vtensor<[1],si64>
func.func @torch.aten.index_select$const_si_si() -> !torch.vtensor<[1],si64> {
%tensor = torch.vtensor.literal(dense<[10,20,30,40,50,60,70,80,90,100]> : tensor<10xsi64>) : !torch.vtensor<[10],si64>
%dim = torch.constant.int 0
%index = torch.vtensor.literal(dense<5> : tensor<1xsi64>) : !torch.vtensor<[1],si64>
%0 = torch.aten.index_select %tensor, %dim, %index : !torch.vtensor<[10],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64>
return %0 : !torch.vtensor<[1],si64>
}
// CHECK-LABEL: func.func @torch.aten.index_select$const_si_ui(
// CHECK-NEXT: %[[RES:.*]] = torch.vtensor.literal(dense<60> : tensor<1xsi64>) : !torch.vtensor<[1],si64>
// CHECK-NEXT: return %[[RES]] : !torch.vtensor<[1],si64>
func.func @torch.aten.index_select$const_si_ui() -> !torch.vtensor<[1],si64> {
%tensor = torch.vtensor.literal(dense<[10,20,30,40,50,60,70,80,90,100]> : tensor<10xsi64>) : !torch.vtensor<[10],si64>
%dim = torch.constant.int 0
%index = torch.vtensor.literal(dense<5> : tensor<1xui64>) : !torch.vtensor<[1],ui64>
%0 = torch.aten.index_select %tensor, %dim, %index : !torch.vtensor<[10],si64>, !torch.int, !torch.vtensor<[1],ui64> -> !torch.vtensor<[1],si64>
return %0 : !torch.vtensor<[1],si64>
}
// CHECK-LABEL: func.func @torch.aten.index_select$const_f32_ui(
// CHECK-NEXT: %[[RES:.*]] = torch.vtensor.literal(dense<6.6{{.*}}> : tensor<1xf32>) : !torch.vtensor<[1],f32>
// CHECK-NEXT: return %[[RES]] : !torch.vtensor<[1],f32>
func.func @torch.aten.index_select$const_f32_ui() -> !torch.vtensor<[1],f32> {
%tensor = torch.vtensor.literal(dense<[1.1,2.2,3.3,4.4,5.5,6.6,7.7,8.8,9.9,10.0]> : tensor<10xf32>) : !torch.vtensor<[10],f32>
%dim = torch.constant.int 0
%index = torch.vtensor.literal(dense<5> : tensor<1xui64>) : !torch.vtensor<[1],ui64>
%0 = torch.aten.index_select %tensor, %dim, %index : !torch.vtensor<[10],f32>, !torch.int, !torch.vtensor<[1],ui64> -> !torch.vtensor<[1],f32>
return %0 : !torch.vtensor<[1],f32>
}
// CHECK-LABEL: func.func @torch.aten.index_select$const_f32_si_neg(
// CHECK-NEXT: %[[RES:.*]] = torch.vtensor.literal(dense<7.{{.*}}> : tensor<1xf32>) : !torch.vtensor<[1],f32>
// CHECK-NEXT: return %[[RES]] : !torch.vtensor<[1],f32>
func.func @torch.aten.index_select$const_f32_si_neg() -> !torch.vtensor<[1],f32> {
%tensor = torch.vtensor.literal(dense<[1.1,2.2,3.3,4.4,5.5,6.6,7.7,8.8,9.9,10.0]> : tensor<10xf32>) : !torch.vtensor<[10],f32>
%dim = torch.constant.int -1
%index = torch.vtensor.literal(dense<-4> : tensor<1xsi64>) : !torch.vtensor<[1],si64>
%0 = torch.aten.index_select %tensor, %dim, %index : !torch.vtensor<[10],f32>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],f32>
return %0 : !torch.vtensor<[1],f32>
}