mirror of https://github.com/llvm/torch-mlir
[torch] AtenSliceOp folder that produces splat results (#2869)
Includes `slice` folder and lit tests --------- Co-authored-by: Xida Ren <xida.ren.dev@gmail.com>revert-2869-AtenSlice-folder
parent
723b8b1d28
commit
fc04bc7ee9
|
@ -2552,22 +2552,52 @@ OpFoldResult AtenBroadcastToOp::fold(FoldAdaptor adaptor) {
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult AtenSliceTensorOp::fold(FoldAdaptor adaptor) {
|
||||
int64_t start, end, step;
|
||||
if (matchPattern(getStart(), m_TorchConstantInt(&start)) &&
|
||||
matchPattern(getEnd(), m_TorchConstantInt(&end)) &&
|
||||
matchPattern(getStep(), m_TorchConstantInt(&step)) && step == 1 &&
|
||||
start == 0 && end == std::numeric_limits<int64_t>::max())
|
||||
DenseElementsAttr input =
|
||||
dyn_cast_or_null<DenseElementsAttr>(adaptor.getSelf());
|
||||
IntegerAttr start = dyn_cast_or_null<IntegerAttr>(adaptor.getStart());
|
||||
IntegerAttr end = dyn_cast_or_null<IntegerAttr>(adaptor.getEnd());
|
||||
IntegerAttr step = dyn_cast_or_null<IntegerAttr>(adaptor.getStep());
|
||||
IntegerAttr dim = dyn_cast_or_null<IntegerAttr>(adaptor.getDim());
|
||||
|
||||
if (start && end && step && step.getValue().getSExtValue() == 1 &&
|
||||
start.getValue().getSExtValue() == 0 &&
|
||||
end.getValue().getSExtValue() == std::numeric_limits<int64_t>::max())
|
||||
return getOperand(0);
|
||||
|
||||
auto inType = getOperand(0).getType().dyn_cast<BaseTensorType>();
|
||||
auto outType = getResult().getType().dyn_cast<BaseTensorType>();
|
||||
if (inType != outType)
|
||||
return nullptr;
|
||||
if (!inType || !outType || !inType.hasSizes() || !outType.hasSizes())
|
||||
auto inType = getOperand(0).getType().dyn_cast<ValueTensorType>();
|
||||
auto outType = getResult().getType().dyn_cast<ValueTensorType>();
|
||||
if (!inType || !outType || !inType.hasSizes() || !outType.hasSizes() ||
|
||||
!inType.hasDtype() || !outType.hasDtype() ||
|
||||
inType.getDtype() != outType.getDtype())
|
||||
return nullptr;
|
||||
|
||||
if (inType.getSizes().size() != outType.getSizes().size() ||
|
||||
!inType.areAllSizesKnown() || !outType.areAllSizesKnown())
|
||||
return nullptr;
|
||||
|
||||
if (input && input.isSplat())
|
||||
return DenseElementsAttr::get(
|
||||
outType.toBuiltinTensor().clone(inType.getDtype()),
|
||||
input.getSplatValue<Attribute>());
|
||||
|
||||
// If the output is a single value we can index into a constant input and grab
|
||||
// that single value:
|
||||
if (input && start && dim &&
|
||||
llvm::all_of(outType.getSizes(), [](int64_t dim) { return dim == 1; })) {
|
||||
bool unaryNonDim = true;
|
||||
int64_t dimInt = dim.getValue().getSExtValue();
|
||||
for (int i = 0, s = inType.getSizes().size(); i < s; ++i) {
|
||||
unaryNonDim &= inType.getSizes()[i] == 1 || i == dimInt;
|
||||
}
|
||||
if (unaryNonDim) {
|
||||
Attribute value =
|
||||
input.getValues<Attribute>()[start.getValue().getSExtValue()];
|
||||
return DenseElementsAttr::get(
|
||||
outType.toBuiltinTensor().clone(inType.getDtype()), value);
|
||||
}
|
||||
}
|
||||
|
||||
// If the input and output shapes are the same we can just fold:
|
||||
for (size_t i = 0; i < inType.getSizes().size(); ++i) {
|
||||
if (inType.getSizes()[i] != outType.getSizes()[i])
|
||||
return nullptr;
|
||||
|
|
|
@ -2104,6 +2104,53 @@ func.func @torch.aten.slice.tensor$no_fold_step(%arg0: !torch.vtensor<[?],f32>,
|
|||
return %0 : !torch.vtensor<[?],f32>
|
||||
}
|
||||
|
||||
// -----
|
||||
// CHECK-LABEL: func.func @torch.aten.slice.tensor$fold_dim_1() -> (!torch.vtensor<[1,1],si64>, !torch.vtensor<[1,1],si64>) {
|
||||
// CHECK-NOT: torch.aten.slice.Tensor
|
||||
// CHECK: %[[RET_0:.*]] = torch.vtensor.literal(dense<50> : tensor<1x1xsi64>) : !torch.vtensor<[1,1],si64>
|
||||
// CHECK-NOT: torch.aten.slice.Tensor
|
||||
// CHECK: %[[RET_1:.*]] = torch.vtensor.literal(dense<70> : tensor<1x1xsi64>) : !torch.vtensor<[1,1],si64>
|
||||
// CHECK-NOT: torch.aten.slice.Tensor
|
||||
// CHECK: return %[[RET_0]], %[[RET_1]]
|
||||
func.func @torch.aten.slice.tensor$fold_dim_1() -> (!torch.vtensor<[1, 1],si64>, !torch.vtensor<[1, 1],si64>) {
|
||||
%tensor = torch.vtensor.literal(dense<[[10,20,30,40,50,60,70,80,90,100]]> : tensor<1x10xsi64>) : !torch.vtensor<[1, 10],si64>
|
||||
%int0 = torch.constant.int 0
|
||||
%int1 = torch.constant.int 1
|
||||
%int4 = torch.constant.int 4
|
||||
%int5 = torch.constant.int 5
|
||||
%int6 = torch.constant.int 6
|
||||
%int7 = torch.constant.int 7
|
||||
%dim = torch.constant.int 1
|
||||
%0 = torch.aten.slice.Tensor %tensor, %dim, %int4, %int5, %int1 : !torch.vtensor<[1, 10], si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1, 1], si64>
|
||||
%1 = torch.aten.slice.Tensor %tensor, %dim, %int6, %int7, %int1 : !torch.vtensor<[1, 10], si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1, 1], si64>
|
||||
return %0, %1 : !torch.vtensor<[1,1],si64>, !torch.vtensor<[1,1],si64>
|
||||
}
|
||||
|
||||
|
||||
// -----
|
||||
// CHECK-LABEL: func.func @torch.aten.slice.tensor$fold_dim_0() -> (!torch.vtensor<[1,1],f32>, !torch.vtensor<[1,1],f32>) {
|
||||
// CHECK-NOT: torch.aten.slice.Tensor
|
||||
// CHECK: %[[RET_0:.*]] = torch.vtensor.literal(dense<1.600000e+01> : tensor<1x1xf32>) : !torch.vtensor<[1,1],f32>
|
||||
// CHECK-NOT: torch.aten.slice.Tensor
|
||||
// CHECK: %[[RET_1:.*]] = torch.vtensor.literal(dense<6.400000e+01> : tensor<1x1xf32>) : !torch.vtensor<[1,1],f32>
|
||||
// CHECK-NOT: torch.aten.slice.Tensor
|
||||
// CHECK: return %[[RET_0]], %[[RET_1]] : !torch.vtensor<[1,1],f32>, !torch.vtensor<[1,1],f32>
|
||||
func.func @torch.aten.slice.tensor$fold_dim_0() -> (!torch.vtensor<[1, 1],f32>, !torch.vtensor<[1, 1],f32>) {
|
||||
%tensor = torch.vtensor.literal(dense<[[2.0],[4.0],[8.0],[16.0],[32.0],[64.0],[128.0],[256.0],[512.0],[1024.0]]> : tensor<10x1xf32>) : !torch.vtensor<[10, 1],f32>
|
||||
%int0 = torch.constant.int 0
|
||||
%int1 = torch.constant.int 1
|
||||
%int3 = torch.constant.int 3
|
||||
%int4 = torch.constant.int 4
|
||||
%int5 = torch.constant.int 5
|
||||
%int6 = torch.constant.int 6
|
||||
%dim = torch.constant.int 0
|
||||
%0 = torch.aten.slice.Tensor %tensor, %dim, %int3, %int4, %int1 : !torch.vtensor<[10, 1], f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1, 1], f32>
|
||||
%1 = torch.aten.slice.Tensor %tensor, %dim, %int5, %int6, %int1 : !torch.vtensor<[10, 1], f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1, 1], f32>
|
||||
return %0, %1 : !torch.vtensor<[1, 1],f32>, !torch.vtensor<[1, 1], f32>
|
||||
}
|
||||
|
||||
|
||||
|
||||
// CHECK-LABEL: func.func @torch.aten.rsub.Scalar$canonicalize_literal_0d() -> !torch.vtensor<[],si64> {
|
||||
// CHECK: %int-1 = torch.constant.int -1
|
||||
// CHECK: %[[VAL_0:.*]] = torch.prim.NumToTensor.Scalar %int-1 : !torch.int -> !torch.vtensor<[],si64>
|
||||
|
|
Loading…
Reference in New Issue