update file check

tanyo/slice_scatter_stage
TanyoKwok 2022-11-22 10:52:50 +08:00
parent 7aac35d51a
commit a5894dbf09
1 changed files with 23 additions and 8 deletions

View File

@ -810,14 +810,29 @@ func.func @torch.aten.repeat(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.int
// -----
// CHECK-LABEL: func @torch.aten.select_scatter
// CHECK-SAME: (%[[SELF:.*]]: !torch.vtensor<[?,?],f32>, %[[SRC:.*]]: !torch.vtensor<[?],f32>) -> !torch.vtensor<[?,?],f32> {
// CHECK-NEXT: %[[START:.*]] = torch.constant.int 0
// CHECK-NEXT: %[[DIM:.*]] = torch.constant.int 1
// CHECK-NEXT: %[[STEP:.*]] = torch.constant.int 1
// CHECK-NEXT: %[[END:.*]] = torch.aten.add.int %[[START]], %[[STEP]]
// CHECK-NEXT: %[[UNSQUEEZE_SRC:.*]] = torch.aten.unsqueeze %[[SRC]], %[[DIM]]
// CHECK-NEXT: %[[SLICE_SCATTER:.*]] = torch.aten.slice_scatter %[[SELF]], %[[UNSQUEEZE_SRC]], %[[DIM]], %[[START]], %[[END]], %[[STEP]]
// CHECK-NEXT: return %[[SLICE_SCATTER]]
// CHECK-NEXT: }
// CHECK-NEXT: %[[INT0:.*]] = torch.constant.int 0
// CHECK-NEXT: %[[INT1:.*]] = torch.constant.int 1
// CHECK-NEXT: %[[INT1_0:.*]] = torch.constant.int 1
// CHECK-NEXT: %[[T0:.*]] = torch.aten.add.int %[[INT0]], %[[INT1_0]]
// CHECK-NEXT: %[[T1:.*]] = torch.aten.unsqueeze %[[SRC]], %[[INT1]]
// CHECK-NEXT: %[[INT1_1:.*]] = torch.constant.int 1
// CHECK-NEXT: %[[INT0_2:.*]] = torch.constant.int 0
// CHECK-NEXT: %[[NONE:.*]] = torch.constant.none
// CHECK-NEXT: %[[T2:.*]] = torch.aten.size.int %[[SELF]], %[[INT1]]
// CHECK-NEXT: %[[INT0_3:.*]] = torch.constant.int 0
// CHECK-NEXT: %[[INT1_4:.*]] = torch.constant.int 1
// CHECK-NEXT: %[[T3:.*]] = torch.aten.arange.start_step %[[INT0_3]], %[[T2]], %[[INT1_4]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]]
// CHECK-NEXT: %[[T4:.*]] = torch.aten.sub.Scalar %[[T3]], %[[INT0]], %[[INT1_1]]
// CHECK-NEXT: %[[T5:.*]] = torch.aten.remainder.Scalar %[[T4]], %[[INT1_0]]
// CHECK-NEXT: %[[T6:.*]] = torch.aten.eq.Scalar %[[T5]], %[[INT0_2]]
// CHECK-NEXT: %[[T7:.*]] = torch.aten.ge.Scalar %[[T4]], %[[INT0_2]]
// CHECK-NEXT: %[[T8:.*]] = torch.aten.ge.Scalar %[[T3]], %[[T0]]
// CHECK-NEXT: %[[T9:.*]] = torch.aten.bitwise_and.Tensor %[[T6]], %[[T7]]
// CHECK-NEXT: %[[T10:.*]] = torch.aten.bitwise_and.Tensor %[[T9]], %[[T8]]
// CHECK-NEXT: %[[T11:.*]] = torch.prim.ListConstruct %[[INT1_1]], %[[T2]]
// CHECK-NEXT: %[[T12:.*]] = torch.aten.view %[[T10]], %[[T11]]
// CHECK-NEXT: %[[T13:.*]] = torch.aten.where.self %[[T12]], %[[T1]], %[[SELF]]
// CHECK-NEXT: return %[[T13]]
func.func @torch.aten.select_scatter(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?],f32>) -> !torch.vtensor<[?,?],f32> {
%int0 = torch.constant.int 0
%int1 = torch.constant.int 1