mirror of https://github.com/llvm/torch-mlir
Add AtenSliceTOp Canonicalization to SimplifyShapeCalculations pass (#3791)
Some ops were failing to infer the static component of partially dynamic shapes, and the cause was a missing aten.slice.t canonicalization pattern. The lit test included here is an IR dump, taken just before DropAbstractInterpCalculations, of an unflatten op whose output shape failed to refine before this change.
parent: edd1bbec46
commit: 1e431c6a90
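To illustrate the effect of the missing pattern on shape refinement (a sketch, not code taken from the patch): a torch.aten.slice.t whose operand is a visible torch.prim.ListConstruct and whose start/end/step are constants can be rewritten into a ListConstruct of just the selected elements. With hypothetical value names:

  // Before: the slice hides which list elements flow onward.
  %shape = torch.prim.ListConstruct %dim0, %dim1, %int3072 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
  %head = torch.aten.slice.t %shape, %none, %int2, %int1 : !torch.list<int>, !torch.none, !torch.int, !torch.int -> !torch.list<int>

  // After canonicalization: the first two elements are selected directly.
  %head = torch.prim.ListConstruct %dim0, %dim1 : (!torch.int, !torch.int) -> !torch.list<int>

Once the slices fold, the subsequent aten.add.t of the sliced head with the static [4, 768] tail can fold as well, which is what lets the result type in the test below refine from [?,?,4,?] to [?,?,4,768].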
@@ -198,6 +198,7 @@ class SimplifyShapeCalculationsPass
     AtenSizeOp::getCanonicalizationPatterns(patterns, context);
     AtenLenTOp::getCanonicalizationPatterns(patterns, context);
     AtenAddTOp::getCanonicalizationPatterns(patterns, context);
+    AtenSliceTOp::getCanonicalizationPatterns(patterns, context);

     // TODO: Debug visitation order to make this more efficient.
     // A single linear scan should suffice.
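For context on what the added registration does (a general MLIR sketch, not code from this repository): ODS-defined ops with a canonicalizer expose a static getCanonicalizationPatterns(RewritePatternSet &, MLIRContext *) hook, and the pass gathers those patterns and hands them to the greedy rewrite driver. A minimal, self-contained sketch of that mechanism, using a hypothetical helper name:

  #include <utility>
  #include "mlir/IR/PatternMatch.h"
  #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
  #include "torch-mlir/Dialect/Torch/IR/TorchOps.h"

  // Hypothetical helper: collect a few list-related canonicalizations and
  // apply them greedily. The real pass registers many more patterns.
  mlir::LogicalResult simplifyListOps(mlir::Operation *root,
                                      mlir::MLIRContext *context) {
    mlir::RewritePatternSet patterns(context);
    mlir::torch::Torch::AtenAddTOp::getCanonicalizationPatterns(patterns, context);
    mlir::torch::Torch::AtenSliceTOp::getCanonicalizationPatterns(patterns, context);
    return mlir::applyPatternsAndFoldGreedily(root, std::move(patterns));
  }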
@@ -489,3 +489,42 @@ func.func @shape_calc_with_two_uses(%arg0: !torch.vtensor<[2],f32>) -> !torch.vtensor<[2],f32> {
   return %arg0 : !torch.vtensor<[2],f32>
 }
+
+// CHECK-LABEL:   func.func @unflat_shape_partial_dyn
+// CHECK-DAG:       %[[INT768:.*]] = torch.constant.int 768
+// CHECK-DAG:       %[[INT0:.*]] = torch.constant.int 0
+// CHECK-DAG:       %[[INT1:.*]] = torch.constant.int 1
+// CHECK-DAG:       %[[INT4:.*]] = torch.constant.int 4
+// CHECK :          } shapes {
+// CHECK :            %[[SZE0:.*]] = torch.aten.size.int %arg0, %[[INT0]] : !torch.vtensor<[?,?,3072],f32>, !torch.int -> !torch.int
+// CHECK :            %[[SZE1:.*]] = torch.aten.size.int %arg0, %[[INT1]] : !torch.vtensor<[?,?,3072],f32>, !torch.int -> !torch.int
+// CHECK :            %[[LIST:.*]] = torch.prim.ListConstruct %[[SZE0]], %[[SZE1]], %[[INT4]], %[[INT768]] : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
+// CHECK :            torch.shape.calculate.yield.shapes %[[LIST]] : !torch.list<int>
+// CHECK :          } : !torch.vtensor<[?,?,4,768],f32>
+func.func @unflat_shape_partial_dyn(%arg0: !torch.vtensor<[?,?,3072],f32>) -> !torch.vtensor<[?,?,4,?],f32> {
+  %int768 = torch.constant.int 768
+  %int3072 = torch.constant.int 3072
+  %int0 = torch.constant.int 0
+  %int3 = torch.constant.int 3
+  %int1 = torch.constant.int 1
+  %none = torch.constant.none
+  %int-1 = torch.constant.int -1
+  %int2 = torch.constant.int 2
+  %int4 = torch.constant.int 4
+  %0 = torch.prim.ListConstruct %int4, %int-1 : (!torch.int, !torch.int) -> !torch.list<int>
+  %1 = torch.shape.calculate {
+    %2 = torch.aten.unflatten.int %arg0, %int2, %0 : !torch.vtensor<[?,?,3072],f32>, !torch.int, !torch.list<int> -> !torch.vtensor<[?,?,4,?],f32>
+    torch.shape.calculate.yield %2 : !torch.vtensor<[?,?,4,?],f32>
+  } shapes {
+    %2 = torch.aten.size.int %arg0, %int0 : !torch.vtensor<[?,?,3072],f32>, !torch.int -> !torch.int
+    %3 = torch.aten.size.int %arg0, %int1 : !torch.vtensor<[?,?,3072],f32>, !torch.int -> !torch.int
+    %4 = torch.prim.ListConstruct %2, %3, %int3072 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
+    %5 = torch.prim.ListConstruct %int4, %int768 : (!torch.int, !torch.int) -> !torch.list<int>
+    %6 = torch.aten.slice.t %4, %none, %int2, %int1 : !torch.list<int>, !torch.none, !torch.int, !torch.int -> !torch.list<int>
+    %7 = torch.aten.add.t %6, %5 : !torch.list<int>, !torch.list<int> -> !torch.list<int>
+    %8 = torch.aten.slice.t %4, %int3, %none, %int1 : !torch.list<int>, !torch.int, !torch.none, !torch.int -> !torch.list<int>
+    %9 = torch.aten.add.t %7, %8 : !torch.list<int>, !torch.list<int> -> !torch.list<int>
+    torch.shape.calculate.yield.shapes %9 : !torch.list<int>
+  } : !torch.vtensor<[?,?,4,?],f32>
+  return %1 : !torch.vtensor<[?,?,4,?],f32>
+}
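To exercise the new test locally, the file's usual lit setup applies. The RUN line is not part of this hunk, so the invocation below is an assumption based on how torch-mlir's shape-simplification tests are normally driven:

  // RUN: torch-mlir-opt -torch-simplify-shape-calculations -split-input-file %s | FileCheck %s

With the AtenSliceTOp patterns registered, the pass should rewrite the shapes region so that the yielded list is the ListConstruct of %[[SZE0]], %[[SZE1]], %[[INT4]], %[[INT768]] checked above, and the shape.calculate result refines to !torch.vtensor<[?,?,4,768],f32>.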