Fix for 0-size dim inferred incorrectly.

The issue was in the canonicalizer for torch.aten.ge.int -- in cases
where the operands were swapped, it would miscompile. This issue is
fixed and folding support generalized to `torch.aten.size.int < 0` as
well.

Fixes #716
pull/643/head
Sean Silva 2022-03-30 21:10:51 +00:00
parent 8250f50c81
commit c17c0a6ba2
3 changed files with 50 additions and 15 deletions

View File

@ -826,8 +826,9 @@ static OpFoldResult intComparatorFoldHelper(OpTy op,
};
comparator = newComparator;
}
// Fold comparisons of negative values with the result of AtenSizeIntOp, which
// is known to always be non-negative.
// Fold comparisons of AtenSizeIntOp against negative values.
// AtenSizeIntOp is known to always be non-negative.
if (rhsIsConstant && rhs < 0) {
// We can return `comparator(0, -1)` here because of the property:
// If x >= 0 && y < 0, then:
@ -837,10 +838,20 @@ static OpFoldResult intComparatorFoldHelper(OpTy op,
if (auto size = lhsValue.getDefiningOp<AtenSizeIntOp>())
return getI1IntegerAttr(op->getContext(), comparator(0, -1));
}
// A special case of importance: size.int >= 0 ==> True.
if (rhsIsConstant && rhs == 0 && isa<AtenGeIntOp>(op)) {
if (auto size = lhsValue.getDefiningOp<AtenSizeIntOp>())
return getI1IntegerAttr(op->getContext(), true);
// Fold comparisons of AtenSizeIntOp against 0:
// - torch.aten.size.int >= 0 ==> True.
// - torch.aten.size.int < 0 ==> False.
// (and the operand-swapped versions of the above)
if (rhsIsConstant && rhs == 0) {
if (auto size = lhsValue.getDefiningOp<AtenSizeIntOp>()) {
// >= 0 comparison.
if (comparator(0, 0) && comparator(1, 0))
return getI1IntegerAttr(op->getContext(), true);
// < 0 comparison.
if (!comparator(0, 0) && comparator(-1, 0) && !comparator(1, 0))
return getI1IntegerAttr(op->getContext(), false);
}
}
return nullptr;

View File

@ -318,7 +318,9 @@ def slice(self: List[int], dim: int, start: Optional[int], end: Optional[int], s
end_val += self[dim]
if start_val < 0:
start_val = 0
elif start_val >= self[dim]:
# TODO: Remove this comment after https://github.com/pytorch/pytorch/pull/74980
# is merged to incorporate our local edit here.
elif start_val > self[dim]:
start_val = self[dim]
if end_val < start_val:
end_val = start_val

View File

@ -375,17 +375,39 @@ func @torch.aten.gt.float$evaluate_to_false() -> !torch.bool {
}
// CHECK-LABEL: func @torch.aten.ge.int$of_size.int(
// CHECK-SAME: %[[ARG:.*]]: !torch.tensor) -> !torch.bool {
// CHECK: %[[TRUE:.*]] = torch.constant.bool true
// CHECK: return %[[TRUE]] : !torch.bool
func @torch.aten.ge.int$of_size.int(%arg0: !torch.tensor) -> !torch.bool {
// CHECK-LABEL: func @comparison_with_torch.aten.size.int(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,2],unk>) -> (!torch.bool, !torch.bool, !torch.bool, !torch.bool, !torch.bool, !torch.bool, !torch.bool, !torch.bool, !torch.bool, !torch.bool, !torch.bool, !torch.bool) {
// CHECK: %[[SIZE:.*]] = torch.aten.size.int %[[ARG0]], %int0 : !torch.vtensor<[?,2],unk>, !torch.int -> !torch.int
// CHECK: %[[GE_0_LHS:.*]] = torch.aten.ge.int %int0, %[[SIZE]] : !torch.int, !torch.int -> !torch.bool
// CHECK: %[[LT_0_LHS:.*]] = torch.aten.lt.int %int0, %[[SIZE]] : !torch.int, !torch.int -> !torch.bool
// CHECK: %[[EQ_0_LHS:.*]] = torch.aten.eq.int %int0, %[[SIZE]] : !torch.int, !torch.int -> !torch.bool
// CHECK: %[[NE_0_LHS:.*]] = torch.aten.ne.int %int0, %[[SIZE]] : !torch.int, !torch.int -> !torch.bool
// CHECK: %[[GT_0_RHS:.*]] = torch.aten.gt.int %[[SIZE]], %int0 : !torch.int, !torch.int -> !torch.bool
// CHECK: %[[LE_0_RHS:.*]] = torch.aten.le.int %[[SIZE]], %int0 : !torch.int, !torch.int -> !torch.bool
// CHECK: %[[EQ_0_RHS:.*]] = torch.aten.eq.int %[[SIZE]], %int0 : !torch.int, !torch.int -> !torch.bool
// CHECK: %[[NE_0_RHS:.*]] = torch.aten.ne.int %[[SIZE]], %int0 : !torch.int, !torch.int -> !torch.bool
// CHECK: return %true, %true, %false, %false, %[[GE_0_LHS]], %[[LT_0_LHS]], %[[EQ_0_LHS]], %[[NE_0_LHS]], %[[GT_0_RHS]], %[[LE_0_RHS]], %[[EQ_0_RHS]], %[[NE_0_RHS]] : !torch.bool, !torch.bool, !torch.bool, !torch.bool, !torch.bool, !torch.bool, !torch.bool, !torch.bool, !torch.bool, !torch.bool, !torch.bool, !torch.bool
func @comparison_with_torch.aten.size.int(%arg0: !torch.vtensor<[?,2],unk>) -> (!torch.bool, !torch.bool, !torch.bool, !torch.bool, !torch.bool, !torch.bool, !torch.bool, !torch.bool, !torch.bool, !torch.bool, !torch.bool, !torch.bool) {
%int0 = torch.constant.int 0
%0 = torch.aten.size.int %arg0, %int0 : !torch.tensor, !torch.int -> !torch.int
%1 = torch.aten.ge.int %0, %int0 : !torch.int, !torch.int -> !torch.bool
return %1 : !torch.bool
%0 = torch.aten.size.int %arg0, %int0 : !torch.vtensor<[?,2],unk>, !torch.int -> !torch.int
// Cases we can fold.
%1 = torch.aten.le.int %int0, %0 : !torch.int, !torch.int -> !torch.bool
%2 = torch.aten.ge.int %0, %int0 : !torch.int, !torch.int -> !torch.bool
%3 = torch.aten.lt.int %0, %int0 : !torch.int, !torch.int -> !torch.bool
%4 = torch.aten.gt.int %int0, %0 : !torch.int, !torch.int -> !torch.bool
// Cases we cannot fold.
%5 = torch.aten.ge.int %int0, %0 : !torch.int, !torch.int -> !torch.bool
%6 = torch.aten.lt.int %int0, %0 : !torch.int, !torch.int -> !torch.bool
%7 = torch.aten.eq.int %int0, %0 : !torch.int, !torch.int -> !torch.bool
%8 = torch.aten.ne.int %int0, %0 : !torch.int, !torch.int -> !torch.bool
%9 = torch.aten.gt.int %0, %int0 : !torch.int, !torch.int -> !torch.bool
%10 = torch.aten.le.int %0, %int0 : !torch.int, !torch.int -> !torch.bool
%11 = torch.aten.eq.int %0, %int0 : !torch.int, !torch.int -> !torch.bool
%12 = torch.aten.ne.int %0, %int0 : !torch.int, !torch.int -> !torch.bool
return %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12 : !torch.bool, !torch.bool, !torch.bool, !torch.bool, !torch.bool, !torch.bool, !torch.bool, !torch.bool, !torch.bool, !torch.bool, !torch.bool, !torch.bool
}
// CHECK-LABEL: func @torch.aten.eq.float$different_value() -> !torch.bool {
// CHECK: %[[FALSE:.*]] = torch.constant.bool false
// CHECK: return %[[FALSE]] : !torch.bool