mirror of https://github.com/llvm/torch-mlir
[torch] aten.index_select folder (#2871)
Folds aten::index_select ops under the following conditions:

1. If the input and output have the same shape, the indexing operation is a no-op, so just return the input.
2. If the input has shape <1x1x...xNx...x1> (all 1's except for one dim), and the output shape is <1x1x...x1> (all 1's), then there is a single index, so extract the single element value and return a tensor with that value.

---------

Co-authored-by: Dave Liddell <dliddell@xilinx.com>
parent 32dbf99ce2
commit 23647ab2d1
@@ -9785,6 +9785,7 @@ def Torch_AtenIndexSelectOp : Torch_Op<"aten.index_select", [
       printDefaultTorchOp(printer, *this, 3, 1);
     }
   }];
+  let hasFolder = 1;
 }
 
 def Torch_Aten_IndexPutImplOp : Torch_Op<"aten._index_put_impl", [
@@ -294,6 +294,44 @@ bool isListPotentiallyMutated(Value list);
 /// the list.
 bool potentiallyMutatesListOperands(Operation *op);
 
+/// Returns the value from an `IntegerAttr` as an `int64_t`.
+///
+/// @param intAttr the `IntegerAttr` from which to extract the value
+/// @return the value as an `int64_t`
+///
+/// Regardless of the signedness of the attribute, this function returns
+/// the value as a signed integer, which implies that if the attribute holds
+/// a 64-bit unsigned value, it is converted to an int64_t in the same way
+/// that a uint64_t is cast to an int64_t in C++.
+inline int64_t getIntAttrAsSigned(IntegerAttr intAttr) {
+  if (intAttr.getType().isUnsignedInteger())
+    return intAttr.getValue().getZExtValue();
+  return intAttr.getValue().getSExtValue();
+}
+
+/// Returns the value from an `IntegerAttr` as an integral index.
+///
+/// @param intAttr the `IntegerAttr` from which to extract the index
+/// @param dimSize the size of the dimension that the attribute indexes into
+/// @return the index value
+///
+/// Use this function when the given `IntegerAttr` represents an index into
+/// a range, such as an index into a tensor dimension. If `dimSize` is given,
+/// negative index values are converted into positive values by counting
+/// elements from the "right" side of the dimension, as in Python, NumPy,
+/// etc. For example, an index of -2 with a dimSize of 10 returns 8, because
+/// 8 is the 2nd index from the high end of the range 0 to 9. If `dimSize`
+/// is not given, negative indices are returned as-is.
+///
+/// No bounds checking is performed to ensure that the index is within the
+/// legal range for `dimSize`.
+inline int64_t getIntAttrAsIndex(IntegerAttr intAttr, int dimSize = -1) {
+  int64_t signedIndex = getIntAttrAsSigned(intAttr);
+  if (dimSize < 0 || signedIndex >= 0)
+    return signedIndex;
+  return dimSize + signedIndex; // count backwards from dimSize
+}
+
 } // namespace Torch
 } // namespace torch
 } // namespace mlir
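For illustration (not part of the commit), a minimal usage sketch of the two helpers above. It assumes the header containing these declarations is included alongside the standard MLIR IR headers; the expected values follow directly from the doc comments:

// Hypothetical sketch exercising getIntAttrAsSigned / getIntAttrAsIndex.
// Include the torch-mlir header that declares the helpers shown above.
#include "mlir/IR/Builders.h"
#include "mlir/IR/MLIRContext.h"
#include <cassert>

using namespace mlir;

void helperExamples() {
  MLIRContext context;
  Builder b(&context);

  // A signed 64-bit attribute holding -2 comes back sign-extended as -2.
  IntegerAttr si =
      b.getIntegerAttr(b.getIntegerType(64, /*isSigned=*/true), -2);
  assert(torch::Torch::getIntAttrAsSigned(si) == -2);

  // The worked example from the doc comment: index -2 into a dimension of
  // size 10 normalizes to 8; with no dimSize it is returned as-is.
  assert(torch::Torch::getIntAttrAsIndex(si, /*dimSize=*/10) == 8);
  assert(torch::Torch::getIntAttrAsIndex(si) == -2);
}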
@@ -2911,6 +2911,91 @@ OpFoldResult AtenDivIntOp::fold(FoldAdaptor adaptor) {
   return nullptr;
 }
 
+//===----------------------------------------------------------------------===//
+// AtenIndexSelectOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult AtenIndexSelectOp::fold(FoldAdaptor adaptor) {
+  auto self = getSelf();
+  auto index = getIndex();
+  auto selfTy = dyn_cast<ValueTensorType>(self.getType());
+  auto indexTy = dyn_cast<ValueTensorType>(index.getType());
+  auto resultTy = dyn_cast<ValueTensorType>(getType());
+  if (!selfTy || !indexTy || !resultTy || !selfTy.hasSizes() ||
+      !indexTy.hasSizes() || !resultTy.hasSizes() || !selfTy.hasDtype() ||
+      !indexTy.hasDtype() || !resultTy.hasDtype())
+    return nullptr;
+
+  auto selfSizes = selfTy.getSizes();
+  auto indexSizes = indexTy.getSizes();
+  auto resultSizes = resultTy.getSizes();
+
+  if (selfTy.getDtype() != resultTy.getDtype() ||
+      selfSizes.size() != resultSizes.size() || indexSizes.size() != 1)
+    return nullptr;
+
+  // If the selection results in a tensor of the same dimensions as the
+  // input, the selection must have specified every index of the input,
+  // so the result is exactly the same as the input.
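+  // For example, in the $noop test below, a select whose result type
+  // !torch.vtensor<[1,2,3],si64> matches the input type folds to the
+  // input value itself.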
+  bool fullTensor = true;
+  for (int i = 0, s = selfSizes.size(); i < s; ++i) {
+    fullTensor &= selfSizes[i] == resultSizes[i];
+    fullTensor &= selfSizes[i] != Torch::kUnknownSize;
+    fullTensor &= resultSizes[i] != Torch::kUnknownSize;
+  }
+
+  if (fullTensor && indexSizes[0] == 1)
+    return self;
+
+  // If the input tensor, the selected dimension, or the indices are not
+  // constant, the op cannot be folded.
+  auto selfAttr = dyn_cast_or_null<DenseElementsAttr>(adaptor.getSelf());
+  auto dimAttr = dyn_cast_or_null<IntegerAttr>(adaptor.getDim());
+  auto indexAttr = dyn_cast_or_null<DenseElementsAttr>(adaptor.getIndex());
+
+  if (!selfAttr || !dimAttr || !indexAttr)
+    return nullptr;
+
+  // If the input's dimensions are all 1 except for one dimension, and if
+  // there is a single index in the index list (as detected by the result
+  // dimension being 1), then fold to a <1x1x...x1> tensor literal containing
+  // a single element. Handles float and int types.
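+  // For example, in the $const_si_si test below, selecting index 5 from the
+  // rank-1 literal dense<[10,20,...,100]> : tensor<10xsi64> folds to
+  // dense<60> : tensor<1xsi64>.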
+  int64_t dimInt = dimAttr.getInt();
+  // If the selected dim is negative, count backwards from the last dim.
+  if (dimInt < 0)
+    dimInt = selfSizes.size() + dimInt;
+  assert(uint64_t(dimInt) < selfSizes.size() &&
+         "selected dim must be less than the number of dims");
+
+  for (int i = 0, s = selfSizes.size(); i < s; ++i) {
+    if ((selfSizes[i] != 1 && i != dimInt) || resultSizes[i] != 1)
+      return nullptr;
+  }
+
+  // Get the single index value for the selected dimension.
+  auto splatValue = indexAttr.getSplatValue<IntegerAttr>();
+  int64_t indexInt = getIntAttrAsIndex(splatValue, selfSizes[dimInt]);
+
+  // Extract the single constant value from the input tensor and turn it into
+  // a single-element tensor of the output shape and dtype.
+  auto splattr = selfAttr.getValues<Attribute>()[indexInt];
+
+  auto dty = resultTy.getDtype();
+  auto attrTy = resultTy.toBuiltinTensor().clone(dty);
+  if (auto floatAttr = dyn_cast<FloatAttr>(splattr))
+    return DenseElementsAttr::get(
+        attrTy, FloatAttr::get(dty, floatAttr.getValueAsDouble()));
+
+  if (auto intAttr = dyn_cast<IntegerAttr>(splattr)) {
+    return DenseElementsAttr::get(attrTy,
+                                  IntegerAttr::get(dty, intAttr.getValue()));
+  }
+  return nullptr;
+}
+
 //===----------------------------------------------------------------------===//
 // AtenItemOp
 //===----------------------------------------------------------------------===//
@@ -616,7 +616,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
     emit("aten::broadcast_to : (Tensor, int[]) -> (Tensor)", has_folder=True)
     emit("aten::index.Tensor : (Tensor, Tensor?[]) -> (Tensor)")
     emit("aten::index.Tensor_hacked_twin : (Tensor, Tensor[]) -> (Tensor)")
-    emit("aten::index_select : (Tensor, int, Tensor) -> (Tensor)")
+    emit("aten::index_select : (Tensor, int, Tensor) -> (Tensor)", has_folder=True)
     emit_with_mutating_variants("aten::_index_put_impl : (Tensor, Tensor?[], Tensor, bool, bool) -> (Tensor)")
     emit("aten::item : (Tensor) -> (Scalar)", has_folder=True)
     emit("aten::masked_select : (Tensor, Tensor) -> (Tensor)")
@@ -2280,3 +2280,55 @@ func.func @torch.aten.detach$canonicalize(%arg0: !torch.tensor<[1],f32>) -> !torch.tensor {
   %1 = torch.aten.detach %arg0 : !torch.tensor<[1],f32> -> !torch.tensor
   return %1 : !torch.tensor
 }
+
+// CHECK-LABEL: func.func @torch.aten.index_select$noop(
+// CHECK-SAME:    %[[ARG:.*]]: !torch.vtensor<[1,2,3],si64>
+// CHECK-NEXT:    return %[[ARG]] : !torch.vtensor<[1,2,3],si64>
+func.func @torch.aten.index_select$noop(%arg0 : !torch.vtensor<[1,2,3],si64>, %arg1 : !torch.int, %arg2 : !torch.vtensor<[1],si64>) -> !torch.vtensor<[1,2,3],si64> {
+  %0 = torch.aten.index_select %arg0, %arg1, %arg2 : !torch.vtensor<[1,2,3],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1,2,3],si64>
+  return %0 : !torch.vtensor<[1,2,3],si64>
+}
+
+// CHECK-LABEL: func.func @torch.aten.index_select$const_si_si(
+// CHECK-NEXT:    %[[RES:.*]] = torch.vtensor.literal(dense<60> : tensor<1xsi64>) : !torch.vtensor<[1],si64>
+// CHECK-NEXT:    return %[[RES]] : !torch.vtensor<[1],si64>
+func.func @torch.aten.index_select$const_si_si() -> !torch.vtensor<[1],si64> {
+  %tensor = torch.vtensor.literal(dense<[10,20,30,40,50,60,70,80,90,100]> : tensor<10xsi64>) : !torch.vtensor<[10],si64>
+  %dim = torch.constant.int 0
+  %index = torch.vtensor.literal(dense<5> : tensor<1xsi64>) : !torch.vtensor<[1],si64>
+  %0 = torch.aten.index_select %tensor, %dim, %index : !torch.vtensor<[10],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64>
+  return %0 : !torch.vtensor<[1],si64>
+}
+
+// CHECK-LABEL: func.func @torch.aten.index_select$const_si_ui(
+// CHECK-NEXT:    %[[RES:.*]] = torch.vtensor.literal(dense<60> : tensor<1xsi64>) : !torch.vtensor<[1],si64>
+// CHECK-NEXT:    return %[[RES]] : !torch.vtensor<[1],si64>
+func.func @torch.aten.index_select$const_si_ui() -> !torch.vtensor<[1],si64> {
+  %tensor = torch.vtensor.literal(dense<[10,20,30,40,50,60,70,80,90,100]> : tensor<10xsi64>) : !torch.vtensor<[10],si64>
+  %dim = torch.constant.int 0
+  %index = torch.vtensor.literal(dense<5> : tensor<1xui64>) : !torch.vtensor<[1],ui64>
+  %0 = torch.aten.index_select %tensor, %dim, %index : !torch.vtensor<[10],si64>, !torch.int, !torch.vtensor<[1],ui64> -> !torch.vtensor<[1],si64>
+  return %0 : !torch.vtensor<[1],si64>
+}
+
+// CHECK-LABEL: func.func @torch.aten.index_select$const_f32_ui(
+// CHECK-NEXT:    %[[RES:.*]] = torch.vtensor.literal(dense<6.6{{.*}}> : tensor<1xf32>) : !torch.vtensor<[1],f32>
+// CHECK-NEXT:    return %[[RES]] : !torch.vtensor<[1],f32>
+func.func @torch.aten.index_select$const_f32_ui() -> !torch.vtensor<[1],f32> {
+  %tensor = torch.vtensor.literal(dense<[1.1,2.2,3.3,4.4,5.5,6.6,7.7,8.8,9.9,10.0]> : tensor<10xf32>) : !torch.vtensor<[10],f32>
+  %dim = torch.constant.int 0
+  %index = torch.vtensor.literal(dense<5> : tensor<1xui64>) : !torch.vtensor<[1],ui64>
+  %0 = torch.aten.index_select %tensor, %dim, %index : !torch.vtensor<[10],f32>, !torch.int, !torch.vtensor<[1],ui64> -> !torch.vtensor<[1],f32>
+  return %0 : !torch.vtensor<[1],f32>
+}
+
+// CHECK-LABEL: func.func @torch.aten.index_select$const_f32_si_neg(
+// CHECK-NEXT:    %[[RES:.*]] = torch.vtensor.literal(dense<7.{{.*}}> : tensor<1xf32>) : !torch.vtensor<[1],f32>
+// CHECK-NEXT:    return %[[RES]] : !torch.vtensor<[1],f32>
+func.func @torch.aten.index_select$const_f32_si_neg() -> !torch.vtensor<[1],f32> {
+  %tensor = torch.vtensor.literal(dense<[1.1,2.2,3.3,4.4,5.5,6.6,7.7,8.8,9.9,10.0]> : tensor<10xf32>) : !torch.vtensor<[10],f32>
+  %dim = torch.constant.int -1
+  %index = torch.vtensor.literal(dense<-4> : tensor<1xsi64>) : !torch.vtensor<[1],si64>
+  %0 = torch.aten.index_select %tensor, %dim, %index : !torch.vtensor<[10],f32>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],f32>
+  return %0 : !torch.vtensor<[1],f32>
+}