mirror of https://github.com/llvm/torch-mlir
[torhc] aten.index_select folder (#2871)
Folds aten::index_select ops under the following conditions: 1. If the input and output are the same shape, the indexing operation is a NOP, so just return the input. 2. If the input has shape <1x1x...xNx...x1> (all 1's except for one dim), and the output shape is <1x1x...x1> (all 1's), then there is a single index, so extract the single element value and return a tensor with that value. --------- Co-authored-by: Dave Liddell <dliddell@xilinx.com>pull/2891/head
parent
32dbf99ce2
commit
23647ab2d1
|
@ -9785,6 +9785,7 @@ def Torch_AtenIndexSelectOp : Torch_Op<"aten.index_select", [
|
||||||
printDefaultTorchOp(printer, *this, 3, 1);
|
printDefaultTorchOp(printer, *this, 3, 1);
|
||||||
}
|
}
|
||||||
}];
|
}];
|
||||||
|
let hasFolder = 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
def Torch_Aten_IndexPutImplOp : Torch_Op<"aten._index_put_impl", [
|
def Torch_Aten_IndexPutImplOp : Torch_Op<"aten._index_put_impl", [
|
||||||
|
|
|
@ -294,6 +294,44 @@ bool isListPotentiallyMutated(Value list);
|
||||||
/// the list.
|
/// the list.
|
||||||
bool potentiallyMutatesListOperands(Operation *op);
|
bool potentiallyMutatesListOperands(Operation *op);
|
||||||
|
|
||||||
|
/// Returns the value from an `IntegerAttr` as an `int64_t`.
|
||||||
|
///
|
||||||
|
/// @param intAttr the `IntegerAttr` from which to extract the value
|
||||||
|
/// @return the value as an `int64_t`
|
||||||
|
///
|
||||||
|
/// Regardless of the signed-ness of the attribute, this function returns
|
||||||
|
/// the value as a signed integer, which implies that if the attribute has
|
||||||
|
/// a 64-bit unsigned value, it will be converted to an int64_t in the manner
|
||||||
|
/// that uint64_t is cast to int64_t in C++.
|
||||||
|
inline int64_t getIntAttrAsSigned(IntegerAttr intAttr) {
|
||||||
|
if (intAttr.getType().isUnsignedInteger())
|
||||||
|
return intAttr.getValue().getZExtValue();
|
||||||
|
return intAttr.getValue().getSExtValue();
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Returns the value from an `IntegerAttr` as an integral index.
|
||||||
|
///
|
||||||
|
/// @param intAttr the `IntegerAttr` from which to extract the index
|
||||||
|
/// @param dimSize the size of the dimension that the attribute indexes into
|
||||||
|
/// @return the index value
|
||||||
|
///
|
||||||
|
/// Use this function when the given `IntegerAttr` represents an index into
|
||||||
|
/// a range, such as an index into a tensor dimension. If `dimSize` is given,
|
||||||
|
/// negative index values are converted into positive vales by counting
|
||||||
|
/// elements from the "right" side of the dimension, as in python, numpy, etc.
|
||||||
|
/// For example, an index of -2 and a dimSize of 10 returns 8 because 8 is the
|
||||||
|
/// 2nd index from the high end of the range 0 to 9. If `dimSize` is not
|
||||||
|
/// given, any negative indices are returned as negative numbers.
|
||||||
|
///
|
||||||
|
/// No bounds checking is performed on the index to ensure that it is within
|
||||||
|
/// the legal range for `dimSize`.
|
||||||
|
inline int64_t getIntAttrAsIndex(IntegerAttr intAttr, int dimSize = -1) {
|
||||||
|
int64_t signedIndex = getIntAttrAsSigned(intAttr);
|
||||||
|
if (dimSize < 0 || signedIndex > 0)
|
||||||
|
return signedIndex;
|
||||||
|
return dimSize + signedIndex; // count backwards from dimSize
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace Torch
|
} // namespace Torch
|
||||||
} // namespace torch
|
} // namespace torch
|
||||||
} // namespace mlir
|
} // namespace mlir
|
||||||
|
|
|
@ -2911,6 +2911,91 @@ OpFoldResult AtenDivIntOp::fold(FoldAdaptor adaptor) {
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// AtenIndexSelectOp
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
OpFoldResult AtenIndexSelectOp::fold(FoldAdaptor adaptor) {
|
||||||
|
auto self = getSelf();
|
||||||
|
auto index = getIndex();
|
||||||
|
auto selfTy = dyn_cast<ValueTensorType>(self.getType());
|
||||||
|
auto indexTy = dyn_cast<ValueTensorType>(index.getType());
|
||||||
|
auto resultTy = dyn_cast<ValueTensorType>(getType());
|
||||||
|
if (!selfTy || !indexTy || !resultTy || !selfTy.hasSizes() ||
|
||||||
|
!indexTy.hasSizes() || !resultTy.hasSizes() || !selfTy.hasDtype() ||
|
||||||
|
!indexTy.hasDtype() || !resultTy.hasDtype())
|
||||||
|
return nullptr;
|
||||||
|
|
||||||
|
auto selfSizes = selfTy.getSizes();
|
||||||
|
auto indexSizes = indexTy.getSizes();
|
||||||
|
auto resultSizes = resultTy.getSizes();
|
||||||
|
|
||||||
|
if (selfTy.getDtype() != resultTy.getDtype() ||
|
||||||
|
selfSizes.size() != resultSizes.size() || indexSizes.size() != 1)
|
||||||
|
return nullptr;
|
||||||
|
|
||||||
|
// If the selection results in a tensor of the same dimensions as the
|
||||||
|
// input, the selection must have specified every index of the input,
|
||||||
|
// so the result is exactly the same as the input.
|
||||||
|
|
||||||
|
bool fullTensor = true;
|
||||||
|
for (int i = 0, s = selfSizes.size(); i < s; ++i) {
|
||||||
|
fullTensor &= selfSizes[i] == resultSizes[i];
|
||||||
|
fullTensor &= selfSizes[i] != Torch::kUnknownSize;
|
||||||
|
fullTensor &= resultSizes[i] != Torch::kUnknownSize;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (fullTensor && indexSizes[0] == 1)
|
||||||
|
return self;
|
||||||
|
|
||||||
|
// If the input tensor, index dimension, or indexes are non-constant,
|
||||||
|
// can't fold.
|
||||||
|
|
||||||
|
auto selfAttr = dyn_cast_or_null<DenseElementsAttr>(adaptor.getSelf());
|
||||||
|
auto dimAttr = dyn_cast_or_null<IntegerAttr>(adaptor.getDim());
|
||||||
|
auto indexAttr = dyn_cast_or_null<DenseElementsAttr>(adaptor.getIndex());
|
||||||
|
|
||||||
|
if (!selfAttr || !dimAttr || !indexAttr)
|
||||||
|
return {};
|
||||||
|
|
||||||
|
// If the input's dimensions are all 1 except for one dimension, and if
|
||||||
|
// there is a single index in the index list (as detected by the result
|
||||||
|
// dimension being 1), then fold to a <1x1x...x1> tensor literal containing
|
||||||
|
// a single element. Handles float and int types.
|
||||||
|
|
||||||
|
int64_t dimInt = dimAttr.getInt();
|
||||||
|
// If the selected dim is negative, count backwards from the last dim
|
||||||
|
if (dimInt < 0)
|
||||||
|
dimInt = selfSizes.size() + dimInt;
|
||||||
|
assert(uint64_t(dimInt) < selfSizes.size() &&
|
||||||
|
"Selected dim > number of dims");
|
||||||
|
|
||||||
|
for (int i = 0, s = selfSizes.size(); i < s; ++i) {
|
||||||
|
if ((selfSizes[i] != 1 && i != dimInt) || resultSizes[i] != 1)
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get the single index value for the selected dimension
|
||||||
|
auto splatValue = indexAttr.getSplatValue<IntegerAttr>();
|
||||||
|
int64_t indexInt = getIntAttrAsIndex(splatValue, selfSizes[dimInt]);
|
||||||
|
|
||||||
|
// Extract the single constant value from the input tensor and turn the
|
||||||
|
// extracted value into a single-element tensor of the output shape and dtype
|
||||||
|
auto splattr = selfAttr.getValues<Attribute>()[indexInt];
|
||||||
|
|
||||||
|
auto dty = resultTy.getDtype();
|
||||||
|
auto attrTy = resultTy.toBuiltinTensor().clone(dty);
|
||||||
|
if (auto floatAttr = dyn_cast<FloatAttr>(splattr))
|
||||||
|
return DenseElementsAttr::get(
|
||||||
|
attrTy, FloatAttr::get(dty, floatAttr.getValueAsDouble()));
|
||||||
|
|
||||||
|
if (auto intAttr = dyn_cast<IntegerAttr>(splattr)) {
|
||||||
|
return DenseElementsAttr::get(attrTy,
|
||||||
|
IntegerAttr::get(dty, intAttr.getValue()));
|
||||||
|
}
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// AtenItemOp
|
// AtenItemOp
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
|
@ -616,7 +616,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
|
||||||
emit("aten::broadcast_to : (Tensor, int[]) -> (Tensor)", has_folder=True)
|
emit("aten::broadcast_to : (Tensor, int[]) -> (Tensor)", has_folder=True)
|
||||||
emit("aten::index.Tensor : (Tensor, Tensor?[]) -> (Tensor)")
|
emit("aten::index.Tensor : (Tensor, Tensor?[]) -> (Tensor)")
|
||||||
emit("aten::index.Tensor_hacked_twin : (Tensor, Tensor[]) -> (Tensor)")
|
emit("aten::index.Tensor_hacked_twin : (Tensor, Tensor[]) -> (Tensor)")
|
||||||
emit("aten::index_select : (Tensor, int, Tensor) -> (Tensor)")
|
emit("aten::index_select : (Tensor, int, Tensor) -> (Tensor)", has_folder=True)
|
||||||
emit_with_mutating_variants("aten::_index_put_impl : (Tensor, Tensor?[], Tensor, bool, bool) -> (Tensor)")
|
emit_with_mutating_variants("aten::_index_put_impl : (Tensor, Tensor?[], Tensor, bool, bool) -> (Tensor)")
|
||||||
emit("aten::item : (Tensor) -> (Scalar)", has_folder=True)
|
emit("aten::item : (Tensor) -> (Scalar)", has_folder=True)
|
||||||
emit("aten::masked_select : (Tensor, Tensor) -> (Tensor)")
|
emit("aten::masked_select : (Tensor, Tensor) -> (Tensor)")
|
||||||
|
|
|
@ -2280,3 +2280,55 @@ func.func @torch.aten.detach$canonicalize(%arg0: !torch.tensor<[1],f32>) -> !tor
|
||||||
%1 = torch.aten.detach %arg0 : !torch.tensor<[1],f32> -> !torch.tensor
|
%1 = torch.aten.detach %arg0 : !torch.tensor<[1],f32> -> !torch.tensor
|
||||||
return %1 : !torch.tensor
|
return %1 : !torch.tensor
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: func.func @torch.aten.index_select$noop(
|
||||||
|
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[1,2,3],si64>
|
||||||
|
// CHECK-NEXT: return %[[ARG]] : !torch.vtensor<[1,2,3],si64>
|
||||||
|
func.func @torch.aten.index_select$noop(%arg0 : !torch.vtensor<[1,2,3],si64>, %arg1 : !torch.int, %arg2 : !torch.vtensor<[1],si64>) -> !torch.vtensor<[1,2,3],si64> {
|
||||||
|
%0 = torch.aten.index_select %arg0, %arg1, %arg2 : !torch.vtensor<[1,2,3],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1,2,3],si64>
|
||||||
|
return %0 : !torch.vtensor<[1,2,3],si64>
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: func.func @torch.aten.index_select$const_si_si(
|
||||||
|
// CHECK-NEXT: %[[RES:.*]] = torch.vtensor.literal(dense<60> : tensor<1xsi64>) : !torch.vtensor<[1],si64>
|
||||||
|
// CHECK-NEXT: return %[[RES]] : !torch.vtensor<[1],si64>
|
||||||
|
func.func @torch.aten.index_select$const_si_si() -> !torch.vtensor<[1],si64> {
|
||||||
|
%tensor = torch.vtensor.literal(dense<[10,20,30,40,50,60,70,80,90,100]> : tensor<10xsi64>) : !torch.vtensor<[10],si64>
|
||||||
|
%dim = torch.constant.int 0
|
||||||
|
%index = torch.vtensor.literal(dense<5> : tensor<1xsi64>) : !torch.vtensor<[1],si64>
|
||||||
|
%0 = torch.aten.index_select %tensor, %dim, %index : !torch.vtensor<[10],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64>
|
||||||
|
return %0 : !torch.vtensor<[1],si64>
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: func.func @torch.aten.index_select$const_si_ui(
|
||||||
|
// CHECK-NEXT: %[[RES:.*]] = torch.vtensor.literal(dense<60> : tensor<1xsi64>) : !torch.vtensor<[1],si64>
|
||||||
|
// CHECK-NEXT: return %[[RES]] : !torch.vtensor<[1],si64>
|
||||||
|
func.func @torch.aten.index_select$const_si_ui() -> !torch.vtensor<[1],si64> {
|
||||||
|
%tensor = torch.vtensor.literal(dense<[10,20,30,40,50,60,70,80,90,100]> : tensor<10xsi64>) : !torch.vtensor<[10],si64>
|
||||||
|
%dim = torch.constant.int 0
|
||||||
|
%index = torch.vtensor.literal(dense<5> : tensor<1xui64>) : !torch.vtensor<[1],ui64>
|
||||||
|
%0 = torch.aten.index_select %tensor, %dim, %index : !torch.vtensor<[10],si64>, !torch.int, !torch.vtensor<[1],ui64> -> !torch.vtensor<[1],si64>
|
||||||
|
return %0 : !torch.vtensor<[1],si64>
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: func.func @torch.aten.index_select$const_f32_ui(
|
||||||
|
// CHECK-NEXT: %[[RES:.*]] = torch.vtensor.literal(dense<6.6{{.*}}> : tensor<1xf32>) : !torch.vtensor<[1],f32>
|
||||||
|
// CHECK-NEXT: return %[[RES]] : !torch.vtensor<[1],f32>
|
||||||
|
func.func @torch.aten.index_select$const_f32_ui() -> !torch.vtensor<[1],f32> {
|
||||||
|
%tensor = torch.vtensor.literal(dense<[1.1,2.2,3.3,4.4,5.5,6.6,7.7,8.8,9.9,10.0]> : tensor<10xf32>) : !torch.vtensor<[10],f32>
|
||||||
|
%dim = torch.constant.int 0
|
||||||
|
%index = torch.vtensor.literal(dense<5> : tensor<1xui64>) : !torch.vtensor<[1],ui64>
|
||||||
|
%0 = torch.aten.index_select %tensor, %dim, %index : !torch.vtensor<[10],f32>, !torch.int, !torch.vtensor<[1],ui64> -> !torch.vtensor<[1],f32>
|
||||||
|
return %0 : !torch.vtensor<[1],f32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: func.func @torch.aten.index_select$const_f32_si_neg(
|
||||||
|
// CHECK-NEXT: %[[RES:.*]] = torch.vtensor.literal(dense<7.{{.*}}> : tensor<1xf32>) : !torch.vtensor<[1],f32>
|
||||||
|
// CHECK-NEXT: return %[[RES]] : !torch.vtensor<[1],f32>
|
||||||
|
func.func @torch.aten.index_select$const_f32_si_neg() -> !torch.vtensor<[1],f32> {
|
||||||
|
%tensor = torch.vtensor.literal(dense<[1.1,2.2,3.3,4.4,5.5,6.6,7.7,8.8,9.9,10.0]> : tensor<10xf32>) : !torch.vtensor<[10],f32>
|
||||||
|
%dim = torch.constant.int -1
|
||||||
|
%index = torch.vtensor.literal(dense<-4> : tensor<1xsi64>) : !torch.vtensor<[1],si64>
|
||||||
|
%0 = torch.aten.index_select %tensor, %dim, %index : !torch.vtensor<[10],f32>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],f32>
|
||||||
|
return %0 : !torch.vtensor<[1],f32>
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue