diff --git a/lib/Dialect/Torch/IR/TorchOps.cpp b/lib/Dialect/Torch/IR/TorchOps.cpp index 98d0369d9..1857aff4d 100644 --- a/lib/Dialect/Torch/IR/TorchOps.cpp +++ b/lib/Dialect/Torch/IR/TorchOps.cpp @@ -2552,52 +2552,22 @@ OpFoldResult AtenBroadcastToOp::fold(FoldAdaptor adaptor) { //===----------------------------------------------------------------------===// OpFoldResult AtenSliceTensorOp::fold(FoldAdaptor adaptor) { - DenseElementsAttr input = - dyn_cast_or_null(adaptor.getSelf()); - IntegerAttr start = dyn_cast_or_null(adaptor.getStart()); - IntegerAttr end = dyn_cast_or_null(adaptor.getEnd()); - IntegerAttr step = dyn_cast_or_null(adaptor.getStep()); - IntegerAttr dim = dyn_cast_or_null(adaptor.getDim()); - - if (start && end && step && step.getValue().getSExtValue() == 1 && - start.getValue().getSExtValue() == 0 && - end.getValue().getSExtValue() == std::numeric_limits::max()) + int64_t start, end, step; + if (matchPattern(getStart(), m_TorchConstantInt(&start)) && + matchPattern(getEnd(), m_TorchConstantInt(&end)) && + matchPattern(getStep(), m_TorchConstantInt(&step)) && step == 1 && + start == 0 && end == std::numeric_limits::max()) return getOperand(0); - auto inType = getOperand(0).getType().dyn_cast(); - auto outType = getResult().getType().dyn_cast(); - if (!inType || !outType || !inType.hasSizes() || !outType.hasSizes() || - !inType.hasDtype() || !outType.hasDtype() || - inType.getDtype() != outType.getDtype()) + auto inType = getOperand(0).getType().dyn_cast(); + auto outType = getResult().getType().dyn_cast(); + if (inType != outType) + return nullptr; + if (!inType || !outType || !inType.hasSizes() || !outType.hasSizes()) return nullptr; - if (inType.getSizes().size() != outType.getSizes().size() || !inType.areAllSizesKnown() || !outType.areAllSizesKnown()) return nullptr; - - if (input && input.isSplat()) - return DenseElementsAttr::get( - outType.toBuiltinTensor().clone(inType.getDtype()), - input.getSplatValue()); - - // If the output is a single value we can index into a constant input and grab - // that single value: - if (input && start && dim && - llvm::all_of(outType.getSizes(), [](int64_t dim) { return dim == 1; })) { - bool unaryNonDim = true; - int64_t dimInt = dim.getValue().getSExtValue(); - for (int i = 0, s = inType.getSizes().size(); i < s; ++i) { - unaryNonDim &= inType.getSizes()[i] == 1 || i == dimInt; - } - if (unaryNonDim) { - Attribute value = - input.getValues()[start.getValue().getSExtValue()]; - return DenseElementsAttr::get( - outType.toBuiltinTensor().clone(inType.getDtype()), value); - } - } - - // If the input and output shapes are the same we can just fold: for (size_t i = 0; i < inType.getSizes().size(); ++i) { if (inType.getSizes()[i] != outType.getSizes()[i]) return nullptr; diff --git a/test/Dialect/Torch/canonicalize.mlir b/test/Dialect/Torch/canonicalize.mlir index 1f0d7971a..77a9e8ad3 100644 --- a/test/Dialect/Torch/canonicalize.mlir +++ b/test/Dialect/Torch/canonicalize.mlir @@ -2104,53 +2104,6 @@ func.func @torch.aten.slice.tensor$no_fold_step(%arg0: !torch.vtensor<[?],f32>, return %0 : !torch.vtensor<[?],f32> } -// ----- -// CHECK-LABEL: func.func @torch.aten.slice.tensor$fold_dim_1() -> (!torch.vtensor<[1,1],si64>, !torch.vtensor<[1,1],si64>) { -// CHECK-NOT: torch.aten.slice.Tensor -// CHECK: %[[RET_0:.*]] = torch.vtensor.literal(dense<50> : tensor<1x1xsi64>) : !torch.vtensor<[1,1],si64> -// CHECK-NOT: torch.aten.slice.Tensor -// CHECK: %[[RET_1:.*]] = torch.vtensor.literal(dense<70> : tensor<1x1xsi64>) : !torch.vtensor<[1,1],si64> -// CHECK-NOT: torch.aten.slice.Tensor -// CHECK: return %[[RET_0]], %[[RET_1]] -func.func @torch.aten.slice.tensor$fold_dim_1() -> (!torch.vtensor<[1, 1],si64>, !torch.vtensor<[1, 1],si64>) { - %tensor = torch.vtensor.literal(dense<[[10,20,30,40,50,60,70,80,90,100]]> : tensor<1x10xsi64>) : !torch.vtensor<[1, 10],si64> - %int0 = torch.constant.int 0 - %int1 = torch.constant.int 1 - %int4 = torch.constant.int 4 - %int5 = torch.constant.int 5 - %int6 = torch.constant.int 6 - %int7 = torch.constant.int 7 - %dim = torch.constant.int 1 - %0 = torch.aten.slice.Tensor %tensor, %dim, %int4, %int5, %int1 : !torch.vtensor<[1, 10], si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1, 1], si64> - %1 = torch.aten.slice.Tensor %tensor, %dim, %int6, %int7, %int1 : !torch.vtensor<[1, 10], si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1, 1], si64> - return %0, %1 : !torch.vtensor<[1,1],si64>, !torch.vtensor<[1,1],si64> -} - - -// ----- -// CHECK-LABEL: func.func @torch.aten.slice.tensor$fold_dim_0() -> (!torch.vtensor<[1,1],f32>, !torch.vtensor<[1,1],f32>) { -// CHECK-NOT: torch.aten.slice.Tensor -// CHECK: %[[RET_0:.*]] = torch.vtensor.literal(dense<1.600000e+01> : tensor<1x1xf32>) : !torch.vtensor<[1,1],f32> -// CHECK-NOT: torch.aten.slice.Tensor -// CHECK: %[[RET_1:.*]] = torch.vtensor.literal(dense<6.400000e+01> : tensor<1x1xf32>) : !torch.vtensor<[1,1],f32> -// CHECK-NOT: torch.aten.slice.Tensor -// CHECK: return %[[RET_0]], %[[RET_1]] : !torch.vtensor<[1,1],f32>, !torch.vtensor<[1,1],f32> -func.func @torch.aten.slice.tensor$fold_dim_0() -> (!torch.vtensor<[1, 1],f32>, !torch.vtensor<[1, 1],f32>) { - %tensor = torch.vtensor.literal(dense<[[2.0],[4.0],[8.0],[16.0],[32.0],[64.0],[128.0],[256.0],[512.0],[1024.0]]> : tensor<10x1xf32>) : !torch.vtensor<[10, 1],f32> - %int0 = torch.constant.int 0 - %int1 = torch.constant.int 1 - %int3 = torch.constant.int 3 - %int4 = torch.constant.int 4 - %int5 = torch.constant.int 5 - %int6 = torch.constant.int 6 - %dim = torch.constant.int 0 - %0 = torch.aten.slice.Tensor %tensor, %dim, %int3, %int4, %int1 : !torch.vtensor<[10, 1], f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1, 1], f32> - %1 = torch.aten.slice.Tensor %tensor, %dim, %int5, %int6, %int1 : !torch.vtensor<[10, 1], f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1, 1], f32> - return %0, %1 : !torch.vtensor<[1, 1],f32>, !torch.vtensor<[1, 1], f32> -} - - - // CHECK-LABEL: func.func @torch.aten.rsub.Scalar$canonicalize_literal_0d() -> !torch.vtensor<[],si64> { // CHECK: %int-1 = torch.constant.int -1 // CHECK: %[[VAL_0:.*]] = torch.prim.NumToTensor.Scalar %int-1 : !torch.int -> !torch.vtensor<[],si64>