Improve aten.broadcast_to folder when in strict symbol mode (#2504)

Strict symbolic shapes allow us to assume numpy-style dynamic broadcasts
never occur. This allows us to strengthen the folder for broadcasts to
cases where the rank is the same and all shapes match (including dynamic
sentinel values).
pull/2506/head
Quinn Dawkins 2023-10-05 09:02:10 -04:00 committed by GitHub
parent 14e6da8588
commit ae72eec224
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 11 additions and 1 deletions

View File

@ -2371,7 +2371,8 @@ OpFoldResult AtenBroadcastToOp::fold(FoldAdaptor adaptor) {
if (!inType || !outType || !inType.hasSizes() || !outType.hasSizes())
return nullptr;
if (inType.getSizes().size() != outType.getSizes().size() ||
!inType.areAllSizesKnown() || !outType.areAllSizesKnown())
(!isAssumingStrictSymbolicShapes((*this)->getBlock()) &&
(!inType.areAllSizesKnown() || !outType.areAllSizesKnown())))
return nullptr;
for (size_t i = 0; i < inType.getSizes().size(); ++i) {
if (inType.getSizes()[i] != outType.getSizes()[i])

View File

@ -1983,6 +1983,15 @@ func.func @torch.aten.broadcast_to$fold(%arg0: !torch.vtensor<[3,4,2],f32>) -> !
return %0 : !torch.vtensor<[3,4,2],f32>
}
// CHECK-LABEL: func.func @torch.aten.broadcast_to_strict$fold(
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[?],f32>, {{.*}}) -> !torch.vtensor<[?],f32>
// CHECK-NEXT: return %[[ARG]] : !torch.vtensor<[?],f32>
func.func @torch.aten.broadcast_to_strict$fold(%arg0: !torch.vtensor<[?],f32>, %arg1: !torch.int) -> !torch.vtensor<[?],f32> attributes {torch.assume_strict_symbolic_shapes} {
%list = torch.prim.ListConstruct %arg1 : (!torch.int) -> !torch.list<int>
%0 = torch.aten.broadcast_to %arg0, %list : !torch.vtensor<[?],f32>, !torch.list<int> -> !torch.vtensor<[?],f32>
return %0 : !torch.vtensor<[?],f32>
}
// CHECK-LABEL: @torch.aten.slice.tensor$fold_full_domain_slice
// CHECK-SAME: %[[ARG0:.+]]: !torch.vtensor<[4],f32>
// CHECK: return %[[ARG0]] : !torch.vtensor<[4],f32>