diff --git a/lib/Dialect/Torch/IR/TorchOps.cpp b/lib/Dialect/Torch/IR/TorchOps.cpp index ac03e1d12..e586f5499 100644 --- a/lib/Dialect/Torch/IR/TorchOps.cpp +++ b/lib/Dialect/Torch/IR/TorchOps.cpp @@ -826,8 +826,9 @@ static OpFoldResult intComparatorFoldHelper(OpTy op, }; comparator = newComparator; } - // Fold comparisons of negative values with the result of AtenSizeIntOp, which - // is known to always be non-negative. + + // Fold comparisons of AtenSizeIntOp against negative values. + // AtenSizeIntOp is known to always be non-negative. if (rhsIsConstant && rhs < 0) { // We can return `comparator(0, -1)` here because of the property: // If x >= 0 && y < 0, then: @@ -837,10 +838,20 @@ static OpFoldResult intComparatorFoldHelper(OpTy op, if (auto size = lhsValue.getDefiningOp()) return getI1IntegerAttr(op->getContext(), comparator(0, -1)); } - // A special case of importance: size.int >= 0 ==> True. - if (rhsIsConstant && rhs == 0 && isa(op)) { - if (auto size = lhsValue.getDefiningOp()) - return getI1IntegerAttr(op->getContext(), true); + + // Fold comparisons of AtenSizeIntOp against 0: + // - torch.aten.size.int >= 0 ==> True. + // - torch.aten.size.int < 0 ==> False. + // (and the operand-swapped versions of the above) + if (rhsIsConstant && rhs == 0) { + if (auto size = lhsValue.getDefiningOp()) { + // >= 0 comparison. + if (comparator(0, 0) && comparator(1, 0)) + return getI1IntegerAttr(op->getContext(), true); + // < 0 comparison. + if (!comparator(0, 0) && comparator(-1, 0) && !comparator(1, 0)) + return getI1IntegerAttr(op->getContext(), false); + } } return nullptr; diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/upstream_shape_helpers.py b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/upstream_shape_helpers.py index 4b79de4be..f62b5aa06 100644 --- a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/upstream_shape_helpers.py +++ b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/upstream_shape_helpers.py @@ -318,7 +318,9 @@ def slice(self: List[int], dim: int, start: Optional[int], end: Optional[int], s end_val += self[dim] if start_val < 0: start_val = 0 - elif start_val >= self[dim]: + # TODO: Remove this comment after https://github.com/pytorch/pytorch/pull/74980 + # is merged to incorporate our local edit here. + elif start_val > self[dim]: start_val = self[dim] if end_val < start_val: end_val = start_val diff --git a/test/Dialect/Torch/canonicalize.mlir b/test/Dialect/Torch/canonicalize.mlir index cd2527b23..ef0d45e35 100644 --- a/test/Dialect/Torch/canonicalize.mlir +++ b/test/Dialect/Torch/canonicalize.mlir @@ -375,17 +375,39 @@ func @torch.aten.gt.float$evaluate_to_false() -> !torch.bool { } -// CHECK-LABEL: func @torch.aten.ge.int$of_size.int( -// CHECK-SAME: %[[ARG:.*]]: !torch.tensor) -> !torch.bool { -// CHECK: %[[TRUE:.*]] = torch.constant.bool true -// CHECK: return %[[TRUE]] : !torch.bool -func @torch.aten.ge.int$of_size.int(%arg0: !torch.tensor) -> !torch.bool { +// CHECK-LABEL: func @comparison_with_torch.aten.size.int( +// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,2],unk>) -> (!torch.bool, !torch.bool, !torch.bool, !torch.bool, !torch.bool, !torch.bool, !torch.bool, !torch.bool, !torch.bool, !torch.bool, !torch.bool, !torch.bool) { +// CHECK: %[[SIZE:.*]] = torch.aten.size.int %[[ARG0]], %int0 : !torch.vtensor<[?,2],unk>, !torch.int -> !torch.int +// CHECK: %[[GE_0_LHS:.*]] = torch.aten.ge.int %int0, %[[SIZE]] : !torch.int, !torch.int -> !torch.bool +// CHECK: %[[LT_0_LHS:.*]] = torch.aten.lt.int %int0, %[[SIZE]] : !torch.int, !torch.int -> !torch.bool +// CHECK: %[[EQ_0_LHS:.*]] = torch.aten.eq.int %int0, %[[SIZE]] : !torch.int, !torch.int -> !torch.bool +// CHECK: %[[NE_0_LHS:.*]] = torch.aten.ne.int %int0, %[[SIZE]] : !torch.int, !torch.int -> !torch.bool +// CHECK: %[[GT_0_RHS:.*]] = torch.aten.gt.int %[[SIZE]], %int0 : !torch.int, !torch.int -> !torch.bool +// CHECK: %[[LE_0_RHS:.*]] = torch.aten.le.int %[[SIZE]], %int0 : !torch.int, !torch.int -> !torch.bool +// CHECK: %[[EQ_0_RHS:.*]] = torch.aten.eq.int %[[SIZE]], %int0 : !torch.int, !torch.int -> !torch.bool +// CHECK: %[[NE_0_RHS:.*]] = torch.aten.ne.int %[[SIZE]], %int0 : !torch.int, !torch.int -> !torch.bool +// CHECK: return %true, %true, %false, %false, %[[GE_0_LHS]], %[[LT_0_LHS]], %[[EQ_0_LHS]], %[[NE_0_LHS]], %[[GT_0_RHS]], %[[LE_0_RHS]], %[[EQ_0_RHS]], %[[NE_0_RHS]] : !torch.bool, !torch.bool, !torch.bool, !torch.bool, !torch.bool, !torch.bool, !torch.bool, !torch.bool, !torch.bool, !torch.bool, !torch.bool, !torch.bool +func @comparison_with_torch.aten.size.int(%arg0: !torch.vtensor<[?,2],unk>) -> (!torch.bool, !torch.bool, !torch.bool, !torch.bool, !torch.bool, !torch.bool, !torch.bool, !torch.bool, !torch.bool, !torch.bool, !torch.bool, !torch.bool) { %int0 = torch.constant.int 0 - %0 = torch.aten.size.int %arg0, %int0 : !torch.tensor, !torch.int -> !torch.int - %1 = torch.aten.ge.int %0, %int0 : !torch.int, !torch.int -> !torch.bool - return %1 : !torch.bool + %0 = torch.aten.size.int %arg0, %int0 : !torch.vtensor<[?,2],unk>, !torch.int -> !torch.int + // Cases we can fold. + %1 = torch.aten.le.int %int0, %0 : !torch.int, !torch.int -> !torch.bool + %2 = torch.aten.ge.int %0, %int0 : !torch.int, !torch.int -> !torch.bool + %3 = torch.aten.lt.int %0, %int0 : !torch.int, !torch.int -> !torch.bool + %4 = torch.aten.gt.int %int0, %0 : !torch.int, !torch.int -> !torch.bool + // Cases we cannot fold. + %5 = torch.aten.ge.int %int0, %0 : !torch.int, !torch.int -> !torch.bool + %6 = torch.aten.lt.int %int0, %0 : !torch.int, !torch.int -> !torch.bool + %7 = torch.aten.eq.int %int0, %0 : !torch.int, !torch.int -> !torch.bool + %8 = torch.aten.ne.int %int0, %0 : !torch.int, !torch.int -> !torch.bool + %9 = torch.aten.gt.int %0, %int0 : !torch.int, !torch.int -> !torch.bool + %10 = torch.aten.le.int %0, %int0 : !torch.int, !torch.int -> !torch.bool + %11 = torch.aten.eq.int %0, %int0 : !torch.int, !torch.int -> !torch.bool + %12 = torch.aten.ne.int %0, %int0 : !torch.int, !torch.int -> !torch.bool + return %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12 : !torch.bool, !torch.bool, !torch.bool, !torch.bool, !torch.bool, !torch.bool, !torch.bool, !torch.bool, !torch.bool, !torch.bool, !torch.bool, !torch.bool } + // CHECK-LABEL: func @torch.aten.eq.float$different_value() -> !torch.bool { // CHECK: %[[FALSE:.*]] = torch.constant.bool false // CHECK: return %[[FALSE]] : !torch.bool