mirror of https://github.com/llvm/torch-mlir
Improve aten.broadcast_to folder when in strict symbol mode (#2504)
Strict symbolic shapes allow us to assume numpy-style dynamic broadcasts never occur. This allows us to strengthen the folder for broadcasts to cases where the rank is the same and all shapes match (including dynamic sentinel values).
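To make the effect concrete, here is a minimal standalone sketch of the fold decision under these rules. It is an illustration, not the torch-mlir implementation: the helper name canFoldBroadcast is hypothetical, and -1 is used as the dynamic-size sentinel (torch-mlir spells it Torch::kUnknownSize).

// Standalone sketch of the strengthened broadcast_to fold decision.
// canFoldBroadcast is a hypothetical helper; -1 stands in for the
// dynamic-size sentinel (Torch::kUnknownSize in torch-mlir).
#include <cstdint>
#include <iostream>
#include <vector>

constexpr int64_t kUnknownSize = -1; // dynamic-dimension sentinel (assumed)

// Returns true when broadcast_to(in) -> out is a no-op that may fold away.
bool canFoldBroadcast(const std::vector<int64_t> &in,
                      const std::vector<int64_t> &out,
                      bool strictSymbolicShapes) {
  // Rank must match in all modes; a broadcast that adds dimensions is
  // never a no-op.
  if (in.size() != out.size())
    return false;
  // Without strict symbolic shapes, a dynamic size might be 1 at runtime
  // and legally broadcast numpy-style, so every size must be static.
  if (!strictSymbolicShapes) {
    for (int64_t s : in)
      if (s == kUnknownSize)
        return false;
    for (int64_t s : out)
      if (s == kUnknownSize)
        return false;
  }
  // All sizes must agree position-wise; in strict mode matching dynamic
  // sentinels count as agreement.
  for (size_t i = 0; i < in.size(); ++i)
    if (in[i] != out[i])
      return false;
  return true;
}

int main() {
  // Dynamic shape [?] -> [?]: folds only under strict symbolic shapes.
  std::cout << canFoldBroadcast({kUnknownSize}, {kUnknownSize}, false)  // 0
            << canFoldBroadcast({kUnknownSize}, {kUnknownSize}, true)   // 1
            << canFoldBroadcast({3, 4, 2}, {3, 4, 2}, false)            // 1
            << '\n';
}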
parent 14e6da8588
commit ae72eec224
@@ -2371,7 +2371,8 @@ OpFoldResult AtenBroadcastToOp::fold(FoldAdaptor adaptor) {
   if (!inType || !outType || !inType.hasSizes() || !outType.hasSizes())
     return nullptr;
   if (inType.getSizes().size() != outType.getSizes().size() ||
-      !inType.areAllSizesKnown() || !outType.areAllSizesKnown())
+      (!isAssumingStrictSymbolicShapes((*this)->getBlock()) &&
+       (!inType.areAllSizesKnown() || !outType.areAllSizesKnown())))
     return nullptr;
   for (size_t i = 0; i < inType.getSizes().size(); ++i) {
     if (inType.getSizes()[i] != outType.getSizes()[i])
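Note that the rank-equality check stays unconditional: the fold can only be an identity when no dimensions are added. Strict mode relaxes only the requirement that every size be statically known, since under torch.assume_strict_symbolic_shapes a dynamic size in the input and the matching dynamic size in the output are assumed equal at runtime, so a position-wise match already implies the broadcast is a no-op.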
@@ -1983,6 +1983,15 @@ func.func @torch.aten.broadcast_to$fold(%arg0: !torch.vtensor<[3,4,2],f32>) -> !torch.vtensor<[3,4,2],f32> {
   return %0 : !torch.vtensor<[3,4,2],f32>
 }
 
+// CHECK-LABEL: func.func @torch.aten.broadcast_to_strict$fold(
+// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[?],f32>, {{.*}}) -> !torch.vtensor<[?],f32>
+// CHECK-NEXT: return %[[ARG]] : !torch.vtensor<[?],f32>
+func.func @torch.aten.broadcast_to_strict$fold(%arg0: !torch.vtensor<[?],f32>, %arg1: !torch.int) -> !torch.vtensor<[?],f32> attributes {torch.assume_strict_symbolic_shapes} {
+  %list = torch.prim.ListConstruct %arg1 : (!torch.int) -> !torch.list<int>
+  %0 = torch.aten.broadcast_to %arg0, %list : !torch.vtensor<[?],f32>, !torch.list<int> -> !torch.vtensor<[?],f32>
+  return %0 : !torch.vtensor<[?],f32>
+}
+
 // CHECK-LABEL: @torch.aten.slice.tensor$fold_full_domain_slice
 // CHECK-SAME: %[[ARG0:.+]]: !torch.vtensor<[4],f32>
 // CHECK: return %[[ARG0]] : !torch.vtensor<[4],f32>
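The new test asserts that, inside a function carrying the torch.assume_strict_symbolic_shapes attribute, a same-rank broadcast_to with a matching dynamic size folds to its input even though the target size comes from a runtime !torch.int. Tests of this form are typically driven by torch-mlir-opt with the -canonicalize pass under FileCheck; the exact RUN line lives at the top of the test file, outside this hunk.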