// torch-mlir/test/Dialect/Torch/scalarize-shapes.mlir

// RUN: torch-mlir-opt <%s --torch-scalarize-shapes -split-input-file -verify-diagnostics | FileCheck %s
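
// Basic scalarization of aten._shape_as_tensor: the pass should replace the
// shape-tensor computation with per-dimension torch.aten.size.int queries for
// the dynamic dims, a constant for the static dim (5), and rebuild the result
// via torch.prim.ListConstruct + torch.aten.tensor.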
// CHECK-LABEL: @shape_as_tensor
func.func @shape_as_tensor(%arg0 : !torch.vtensor<[5,?,?],f32>) -> !torch.vtensor<[3],si32> {
// CHECK-DAG: %[[FALSE:.+]] = torch.constant.bool false
// CHECK-DAG: %[[NONE:.+]] = torch.constant.none
// CHECK-DAG: %[[I2:.+]] = torch.constant.int 2
// CHECK-DAG: %[[I5:.+]] = torch.constant.int 5
// CHECK-DAG: %[[I1:.+]] = torch.constant.int 1
// CHECK-DAG: %[[SZ1:.+]] = torch.aten.size.int %arg0, %[[I1]]
// CHECK-DAG: %[[SZ2:.+]] = torch.aten.size.int %arg0, %[[I2]]
// CHECK-DAG: %[[LIST:.+]] = torch.prim.ListConstruct %[[I5]], %[[SZ1]], %[[SZ2]]
// CHECK-DAG: %[[TENSOR:.+]] = torch.aten.tensor %[[LIST]], %[[NONE]], %[[NONE]], %[[FALSE]]
// CHECK: return %[[TENSOR]] : !torch.vtensor<[3],si32>
%0 = torch.aten._shape_as_tensor %arg0 : !torch.vtensor<[5,?,?],f32> -> !torch.vtensor<[3],si32>
return %0 : !torch.vtensor<[3],si32>
}
// -----
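
// Extracting a single dimension from the shape tensor via index_select with a
// constant index: the pass should reduce this to one torch.aten.size.int on
// dim 1 and materialize the rank-0 result with torch.aten.full over an empty
// size list.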
// CHECK-LABEL: @shape_as_tensor_dim
func.func @shape_as_tensor_dim(%arg0 : !torch.vtensor<[5,?,?],f32>) -> !torch.vtensor<[],si32> {
// CHECK: %[[FALSE:.+]] = torch.constant.bool false
// CHECK: %[[NONE:.+]] = torch.constant.none
// CHECK: %[[INT1:.+]] = torch.constant.int 1
// CHECK: %[[SZ:.+]] = torch.aten.size.int %arg0, %[[INT1]]
// CHECK: %[[LIST:.+]] = torch.prim.ListConstruct :
// CHECK: %[[TENSOR:.+]] = torch.aten.full %[[LIST]], %[[SZ]], %[[NONE]], %[[NONE]], %[[NONE]], %[[FALSE]]
// CHECK: return %[[TENSOR]] : !torch.vtensor<[],si32>
%shape = torch.aten._shape_as_tensor %arg0 : !torch.vtensor<[5,?,?],f32> -> !torch.vtensor<[3],si32>
%dim = torch.constant.int 0
%idx = torch.vtensor.literal(dense<1> : tensor<si32>) : !torch.vtensor<[],si32>
%select = torch.aten.index_select %shape, %dim, %idx : !torch.vtensor<[3],si32>, !torch.int, !torch.vtensor<[],si32> -> !torch.vtensor<[],si32>
return %select : !torch.vtensor<[],si32>
}
// -----
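
// Same extraction followed by aten.item: the scalar torch.int should be
// returned directly from torch.aten.size.int, with no tensor materialized.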
// CHECK-LABEL: @shape_as_tensor_dim_item
func.func @shape_as_tensor_dim_item(%arg0 : !torch.vtensor<[5,?,?],f32>) -> !torch.int {
// CHECK-DAG: %[[INT1:.+]] = torch.constant.int 1
// CHECK-DAG: %[[SZ:.+]] = torch.aten.size.int %arg0, %[[INT1]] : !torch.vtensor<[5,?,?],f32>, !torch.int -> !torch.int
// CHECK: return %[[SZ]]
%shape = torch.aten._shape_as_tensor %arg0 : !torch.vtensor<[5,?,?],f32> -> !torch.vtensor<[3],si32>
%dim = torch.constant.int 0
%idx = torch.vtensor.literal(dense<1> : tensor<si32>) : !torch.vtensor<[],si32>
%select = torch.aten.index_select %shape, %dim, %idx : !torch.vtensor<[3],si32>, !torch.int, !torch.vtensor<[],si32> -> !torch.vtensor<[],si32>
%out = torch.aten.item %select : !torch.vtensor<[],si32> -> !torch.int
return %out : !torch.int
}
// -----
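
// Strided slice of the shape tensor (start=1, end=5, step=2 selects dims 1
// and 3): the pass should query only those two sizes and rebuild the rank-1
// result with torch.prim.ListConstruct + torch.aten.tensor.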
// CHECK-LABEL: @shape_as_tensor_slice
func.func @shape_as_tensor_slice(%arg0 : !torch.vtensor<[5,?,?,?],f32>) -> !torch.vtensor<[2],si32> {
// CHECK-DAG: %[[FALSE:.+]] = torch.constant.bool false
// CHECK-DAG: %[[NONE:.+]] = torch.constant.none
// CHECK-DAG: %[[INT3:.+]] = torch.constant.int 3
// CHECK-DAG: %[[INT1:.+]] = torch.constant.int 1
// CHECK-DAG: %[[SZ1:.+]] = torch.aten.size.int %arg0, %[[INT1]]
// CHECK-DAG: %[[SZ3:.+]] = torch.aten.size.int %arg0, %[[INT3]]
// CHECK-DAG: %[[LIST:.+]] = torch.prim.ListConstruct %[[SZ1]], %[[SZ3]]
// CHECK-DAG: %[[TENSOR:.+]] = torch.aten.tensor %[[LIST]], %[[NONE]], %[[NONE]], %[[FALSE]]
// CHECK: return %[[TENSOR]]
%shape = torch.aten._shape_as_tensor %arg0 : !torch.vtensor<[5,?,?,?],f32> -> !torch.vtensor<[4],si32>
%dim = torch.constant.int 0
%start = torch.constant.int 1
%end = torch.constant.int 5
%step = torch.constant.int 2
%slice = torch.aten.slice.Tensor %shape, %dim, %start, %end, %step : !torch.vtensor<[4],si32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[2],si32>
return %slice : !torch.vtensor<[2],si32>
}