mirror of https://github.com/llvm/torch-mlir
Fix for 0-size dim inferred incorrectly.
The issue was in the canonicalizer for torch.aten.ge.int -- in cases where the operands were swapped, it would miscompile. This issue is fixed and folding support generalized to `torch.aten.size.int < 0` as well. Fixes #716pull/643/head
parent
8250f50c81
commit
c17c0a6ba2
|
@ -826,8 +826,9 @@ static OpFoldResult intComparatorFoldHelper(OpTy op,
|
|||
};
|
||||
comparator = newComparator;
|
||||
}
|
||||
// Fold comparisons of negative values with the result of AtenSizeIntOp, which
|
||||
// is known to always be non-negative.
|
||||
|
||||
// Fold comparisons of AtenSizeIntOp against negative values.
|
||||
// AtenSizeIntOp is known to always be non-negative.
|
||||
if (rhsIsConstant && rhs < 0) {
|
||||
// We can return `comparator(0, -1)` here because of the property:
|
||||
// If x >= 0 && y < 0, then:
|
||||
|
@ -837,10 +838,20 @@ static OpFoldResult intComparatorFoldHelper(OpTy op,
|
|||
if (auto size = lhsValue.getDefiningOp<AtenSizeIntOp>())
|
||||
return getI1IntegerAttr(op->getContext(), comparator(0, -1));
|
||||
}
|
||||
// A special case of importance: size.int >= 0 ==> True.
|
||||
if (rhsIsConstant && rhs == 0 && isa<AtenGeIntOp>(op)) {
|
||||
if (auto size = lhsValue.getDefiningOp<AtenSizeIntOp>())
|
||||
return getI1IntegerAttr(op->getContext(), true);
|
||||
|
||||
// Fold comparisons of AtenSizeIntOp against 0:
|
||||
// - torch.aten.size.int >= 0 ==> True.
|
||||
// - torch.aten.size.int < 0 ==> False.
|
||||
// (and the operand-swapped versions of the above)
|
||||
if (rhsIsConstant && rhs == 0) {
|
||||
if (auto size = lhsValue.getDefiningOp<AtenSizeIntOp>()) {
|
||||
// >= 0 comparison.
|
||||
if (comparator(0, 0) && comparator(1, 0))
|
||||
return getI1IntegerAttr(op->getContext(), true);
|
||||
// < 0 comparison.
|
||||
if (!comparator(0, 0) && comparator(-1, 0) && !comparator(1, 0))
|
||||
return getI1IntegerAttr(op->getContext(), false);
|
||||
}
|
||||
}
|
||||
|
||||
return nullptr;
|
||||
|
|
|
@ -318,7 +318,9 @@ def slice(self: List[int], dim: int, start: Optional[int], end: Optional[int], s
|
|||
end_val += self[dim]
|
||||
if start_val < 0:
|
||||
start_val = 0
|
||||
elif start_val >= self[dim]:
|
||||
# TODO: Remove this comment after https://github.com/pytorch/pytorch/pull/74980
|
||||
# is merged to incorporate our local edit here.
|
||||
elif start_val > self[dim]:
|
||||
start_val = self[dim]
|
||||
if end_val < start_val:
|
||||
end_val = start_val
|
||||
|
|
|
@ -375,17 +375,39 @@ func @torch.aten.gt.float$evaluate_to_false() -> !torch.bool {
|
|||
}
|
||||
|
||||
|
||||
// CHECK-LABEL: func @torch.aten.ge.int$of_size.int(
|
||||
// CHECK-SAME: %[[ARG:.*]]: !torch.tensor) -> !torch.bool {
|
||||
// CHECK: %[[TRUE:.*]] = torch.constant.bool true
|
||||
// CHECK: return %[[TRUE]] : !torch.bool
|
||||
func @torch.aten.ge.int$of_size.int(%arg0: !torch.tensor) -> !torch.bool {
|
||||
// CHECK-LABEL: func @comparison_with_torch.aten.size.int(
|
||||
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,2],unk>) -> (!torch.bool, !torch.bool, !torch.bool, !torch.bool, !torch.bool, !torch.bool, !torch.bool, !torch.bool, !torch.bool, !torch.bool, !torch.bool, !torch.bool) {
|
||||
// CHECK: %[[SIZE:.*]] = torch.aten.size.int %[[ARG0]], %int0 : !torch.vtensor<[?,2],unk>, !torch.int -> !torch.int
|
||||
// CHECK: %[[GE_0_LHS:.*]] = torch.aten.ge.int %int0, %[[SIZE]] : !torch.int, !torch.int -> !torch.bool
|
||||
// CHECK: %[[LT_0_LHS:.*]] = torch.aten.lt.int %int0, %[[SIZE]] : !torch.int, !torch.int -> !torch.bool
|
||||
// CHECK: %[[EQ_0_LHS:.*]] = torch.aten.eq.int %int0, %[[SIZE]] : !torch.int, !torch.int -> !torch.bool
|
||||
// CHECK: %[[NE_0_LHS:.*]] = torch.aten.ne.int %int0, %[[SIZE]] : !torch.int, !torch.int -> !torch.bool
|
||||
// CHECK: %[[GT_0_RHS:.*]] = torch.aten.gt.int %[[SIZE]], %int0 : !torch.int, !torch.int -> !torch.bool
|
||||
// CHECK: %[[LE_0_RHS:.*]] = torch.aten.le.int %[[SIZE]], %int0 : !torch.int, !torch.int -> !torch.bool
|
||||
// CHECK: %[[EQ_0_RHS:.*]] = torch.aten.eq.int %[[SIZE]], %int0 : !torch.int, !torch.int -> !torch.bool
|
||||
// CHECK: %[[NE_0_RHS:.*]] = torch.aten.ne.int %[[SIZE]], %int0 : !torch.int, !torch.int -> !torch.bool
|
||||
// CHECK: return %true, %true, %false, %false, %[[GE_0_LHS]], %[[LT_0_LHS]], %[[EQ_0_LHS]], %[[NE_0_LHS]], %[[GT_0_RHS]], %[[LE_0_RHS]], %[[EQ_0_RHS]], %[[NE_0_RHS]] : !torch.bool, !torch.bool, !torch.bool, !torch.bool, !torch.bool, !torch.bool, !torch.bool, !torch.bool, !torch.bool, !torch.bool, !torch.bool, !torch.bool
|
||||
func @comparison_with_torch.aten.size.int(%arg0: !torch.vtensor<[?,2],unk>) -> (!torch.bool, !torch.bool, !torch.bool, !torch.bool, !torch.bool, !torch.bool, !torch.bool, !torch.bool, !torch.bool, !torch.bool, !torch.bool, !torch.bool) {
|
||||
%int0 = torch.constant.int 0
|
||||
%0 = torch.aten.size.int %arg0, %int0 : !torch.tensor, !torch.int -> !torch.int
|
||||
%1 = torch.aten.ge.int %0, %int0 : !torch.int, !torch.int -> !torch.bool
|
||||
return %1 : !torch.bool
|
||||
%0 = torch.aten.size.int %arg0, %int0 : !torch.vtensor<[?,2],unk>, !torch.int -> !torch.int
|
||||
// Cases we can fold.
|
||||
%1 = torch.aten.le.int %int0, %0 : !torch.int, !torch.int -> !torch.bool
|
||||
%2 = torch.aten.ge.int %0, %int0 : !torch.int, !torch.int -> !torch.bool
|
||||
%3 = torch.aten.lt.int %0, %int0 : !torch.int, !torch.int -> !torch.bool
|
||||
%4 = torch.aten.gt.int %int0, %0 : !torch.int, !torch.int -> !torch.bool
|
||||
// Cases we cannot fold.
|
||||
%5 = torch.aten.ge.int %int0, %0 : !torch.int, !torch.int -> !torch.bool
|
||||
%6 = torch.aten.lt.int %int0, %0 : !torch.int, !torch.int -> !torch.bool
|
||||
%7 = torch.aten.eq.int %int0, %0 : !torch.int, !torch.int -> !torch.bool
|
||||
%8 = torch.aten.ne.int %int0, %0 : !torch.int, !torch.int -> !torch.bool
|
||||
%9 = torch.aten.gt.int %0, %int0 : !torch.int, !torch.int -> !torch.bool
|
||||
%10 = torch.aten.le.int %0, %int0 : !torch.int, !torch.int -> !torch.bool
|
||||
%11 = torch.aten.eq.int %0, %int0 : !torch.int, !torch.int -> !torch.bool
|
||||
%12 = torch.aten.ne.int %0, %int0 : !torch.int, !torch.int -> !torch.bool
|
||||
return %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12 : !torch.bool, !torch.bool, !torch.bool, !torch.bool, !torch.bool, !torch.bool, !torch.bool, !torch.bool, !torch.bool, !torch.bool, !torch.bool, !torch.bool
|
||||
}
|
||||
|
||||
|
||||
// CHECK-LABEL: func @torch.aten.eq.float$different_value() -> !torch.bool {
|
||||
// CHECK: %[[FALSE:.*]] = torch.constant.bool false
|
||||
// CHECK: return %[[FALSE]] : !torch.bool
|
||||
|
|
Loading…
Reference in New Issue