[Torch] enhance fold of aten.squeeze.dim (#3558)

pull/3559/head
Yuanqiang Liu 2024-07-24 14:13:48 +08:00 committed by GitHub
parent d1e172f418
commit aad1604046
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 101 additions and 22 deletions

View File

@ -128,6 +128,17 @@ static FloatAttr getF64FloatAttr(MLIRContext *context, double value) {
return FloatAttr::get(Float64Type::get(context), value);
}
static DenseElementsAttr reshapeDenseElementsAttr(DenseElementsAttr attr,
ShapedType newType) {
// TODO: DenseElementsAttr::reshape is broken for bool splats.
// Once that ticket is fixed, we can remove this conditional.
if (attr.isSplat() && newType.getElementType().isInteger(/*width=*/1)) {
auto splatValue = attr.getValues<bool>()[0];
return DenseElementsAttr::get(newType, {splatValue});
}
return attr.reshape(newType);
}
static Value getScalarIntValue(Value input, Location loc,
PatternRewriter &rewriter) {
auto inputType = input.getType();
@ -798,11 +809,22 @@ OpFoldResult AtenSqueezeOp::fold(FoldAdaptor adaptor) {
//===----------------------------------------------------------------------===//
OpFoldResult AtenSqueezeDimOp::fold(FoldAdaptor adaptor) {
if (getOperand(0).getType() != getResult().getType())
auto inType = dyn_cast<ValueTensorType>(getOperand(0).getType());
auto outType = dyn_cast<ValueTensorType>(getResult().getType());
if (!inType || !outType || !inType.areAllSizesKnown() ||
!outType.areAllSizesKnown() || !inType.hasDtype() ||
!outType.hasDtype()) {
return nullptr;
if (auto tensorType = dyn_cast<BaseTensorType>(getOperand(0).getType())) {
if (tensorType.hasSizes() && tensorType.getSizes().size() == 0)
return getOperand(0);
}
if (inType == outType) {
return getOperand(0);
}
DenseElementsAttr input =
dyn_cast_or_null<DenseElementsAttr>(adaptor.getSelf());
if (input) {
return reshapeDenseElementsAttr(input, outType.toBuiltinTensor());
}
return nullptr;
}

View File

@ -379,15 +379,25 @@ func.func @torch.aten.view$to_rank0(%arg0: !torch.vtensor<[1],f32>) -> !torch.vt
// -----
// CHECK-LABEL: func.func @torch.aten.squeeze.dim$0$static(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[2,1,2,1,2],f32>) -> !torch.vtensor<[2,1,2,1,2],f32> {
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[2,1,2,1,2],f32>) -> !torch.vtensor<[2,2,1,2],f32> {
// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[2,1,2,1,2],f32> -> tensor<2x1x2x1x2xf32>
// CHECK: %[[INT0:.*]] = torch.constant.int 0
// CHECK: %[[T1:.*]] = torch_c.from_builtin_tensor %[[T0]] : tensor<2x1x2x1x2xf32> -> !torch.vtensor<[2,1,2,1,2],f32>
// CHECK: return %[[T1]] : !torch.vtensor<[2,1,2,1,2],f32>
func.func @torch.aten.squeeze.dim$0$static(%arg0: !torch.vtensor<[2,1,2,1,2],f32>) -> !torch.vtensor<[2,1,2,1,2],f32> {
%int0 = torch.constant.int 0
%0 = torch.aten.squeeze.dim %arg0, %int0 : !torch.vtensor<[2,1,2,1,2],f32>, !torch.int -> !torch.vtensor<[2,1,2,1,2],f32>
return %0 : !torch.vtensor<[2,1,2,1,2],f32>
// CHECK: %[[INT1:.*]] = torch.constant.int 1
// CHECK: %[[C0:.*]] = arith.constant 0 : index
// CHECK: %[[DIM:.*]] = tensor.dim %[[T0]], %[[C0]] : tensor<2x1x2x1x2xf32>
// CHECK: %[[C2:.*]] = arith.constant 2 : index
// CHECK: %[[DIM_0:.*]] = tensor.dim %[[T0]], %[[C2]] : tensor<2x1x2x1x2xf32>
// CHECK: %[[C3:.*]] = arith.constant 3 : index
// CHECK: %[[DIM_1:.*]] = tensor.dim %[[T0]], %[[C3:.*]] : tensor<2x1x2x1x2xf32>
// CHECK: %[[C4:.*]] = arith.constant 4 : index
// CHECK: %[[DIM_2:.*]] = tensor.dim %[[T0]], %[[C4]] : tensor<2x1x2x1x2xf32>
// CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[DIM]], %[[DIM_0]], %[[DIM_1]], %[[DIM_2]] : tensor<4xindex>
// CHECK: %[[T1:.*]] = stablehlo.dynamic_reshape %[[T0]], %[[FROM_ELEMENTS]] : (tensor<2x1x2x1x2xf32>, tensor<4xindex>) -> tensor<2x2x1x2xf32>
// CHECK: %[[T2:.*]] = torch_c.from_builtin_tensor %[[T1]] : tensor<2x2x1x2xf32> -> !torch.vtensor<[2,2,1,2],f32>
// CHECK: return %[[T2]] : !torch.vtensor<[2,2,1,2],f32>
func.func @torch.aten.squeeze.dim$0$static(%arg0: !torch.vtensor<[2,1,2,1,2],f32>) -> !torch.vtensor<[2,2,1,2],f32> {
%int1 = torch.constant.int 1
%0 = torch.aten.squeeze.dim %arg0, %int1 : !torch.vtensor<[2,1,2,1,2],f32>, !torch.int -> !torch.vtensor<[2,2,1,2],f32>
return %0 : !torch.vtensor<[2,2,1,2],f32>
}
// -----

View File

@ -1507,20 +1507,67 @@ func.func @torch.aten.Float.Tensor(%arg0: !torch.float) -> !torch.float {
}
// CHECK-LABEL: func.func @torch.aten.squeeze$zero_rank(
// CHECK-SAME: %[[ARG:.*]]: !torch.tensor<[],f32>) -> !torch.tensor<[],f32> {
// CHECK-NEXT: return %[[ARG]] : !torch.tensor<[],f32>
func.func @torch.aten.squeeze$zero_rank(%arg0: !torch.tensor<[],f32>) -> !torch.tensor<[],f32> {
%0 = torch.aten.squeeze %arg0 : !torch.tensor<[],f32> -> !torch.tensor<[],f32>
return %0 : !torch.tensor<[],f32>
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[],f32>) -> !torch.vtensor<[],f32> {
// CHECK-NEXT: return %[[ARG]] : !torch.vtensor<[],f32>
func.func @torch.aten.squeeze$zero_rank(%arg0: !torch.vtensor<[],f32>) -> !torch.vtensor<[],f32> {
%0 = torch.aten.squeeze %arg0 : !torch.vtensor<[],f32> -> !torch.vtensor<[],f32>
return %0 : !torch.vtensor<[],f32>
}
// CHECK-LABEL: func.func @torch.aten.squeeze.dim$zero_rank(
// CHECK-SAME: %[[ARG:.*]]: !torch.tensor<[],f32>) -> !torch.tensor<[],f32> {
// CHECK-NEXT: return %[[ARG]] : !torch.tensor<[],f32>
func.func @torch.aten.squeeze.dim$zero_rank(%arg0: !torch.tensor<[],f32>) -> !torch.tensor<[],f32> {
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[],f32>) -> !torch.vtensor<[],f32> {
// CHECK-NEXT: return %[[ARG]] : !torch.vtensor<[],f32>
func.func @torch.aten.squeeze.dim$zero_rank(%arg0: !torch.vtensor<[],f32>) -> !torch.vtensor<[],f32> {
%int0 = torch.constant.int 0
%0 = torch.aten.squeeze.dim %arg0, %int0 : !torch.tensor<[],f32>, !torch.int -> !torch.tensor<[],f32>
return %0 : !torch.tensor<[],f32>
%0 = torch.aten.squeeze.dim %arg0, %int0 : !torch.vtensor<[],f32>, !torch.int -> !torch.vtensor<[],f32>
return %0 : !torch.vtensor<[],f32>
}
// CHECK-LABEL: func.func @torch.aten.squeeze.dim$cst() -> !torch.vtensor<[2],si64> {
// CHECK-NEXT: %[[CST:.+]] = torch.vtensor.literal(dense<[127, 128]> : tensor<2xsi64>) : !torch.vtensor<[2],si64>
// CHECK-NEXT: return %[[CST]]
func.func @torch.aten.squeeze.dim$cst() -> !torch.vtensor<[2],si64> {
%int1 = torch.constant.int 1
%0 = torch.vtensor.literal(dense<[[127], [128]]> : tensor<2x1xsi64>) : !torch.vtensor<[2,1],si64>
%1 = torch.aten.squeeze.dim %0, %int1 : !torch.vtensor<[2,1],si64>, !torch.int -> !torch.vtensor<[2],si64>
return %1 : !torch.vtensor<[2],si64>
}
// CHECK-LABEL: func.func @torch.aten.squeeze.dim$cst_i1() -> !torch.vtensor<[3],i1> {
// CHECK-NEXT: %[[CST:.+]] = torch.vtensor.literal(dense<[true, false, true]> : tensor<3xi1>) : !torch.vtensor<[3],i1>
// CHECK-NEXT: return %[[CST]]
func.func @torch.aten.squeeze.dim$cst_i1() -> !torch.vtensor<[3],i1> {
%int1 = torch.constant.int 1
%0 = torch.vtensor.literal(dense<[[true], [false], [true]]> : tensor<3x1xi1>) : !torch.vtensor<[3,1],i1>
%1 = torch.aten.squeeze.dim %0, %int1 : !torch.vtensor<[3,1],i1>, !torch.int -> !torch.vtensor<[3],i1>
return %1 : !torch.vtensor<[3],i1>
}
// CHECK-LABEL: func.func @torch.aten.squeeze.dim$cst_splat_i1() -> !torch.vtensor<[3],i1> {
// CHECK-NEXT: %[[CST:.+]] = torch.vtensor.literal(dense<true> : tensor<3xi1>) : !torch.vtensor<[3],i1>
// CHECK-NEXT: return %[[CST]]
func.func @torch.aten.squeeze.dim$cst_splat_i1() -> !torch.vtensor<[3],i1> {
%int1 = torch.constant.int 1
%0 = torch.vtensor.literal(dense<true> : tensor<3x1xi1>) : !torch.vtensor<[3,1],i1>
%1 = torch.aten.squeeze.dim %0, %int1 : !torch.vtensor<[3,1],i1>, !torch.int -> !torch.vtensor<[3],i1>
return %1 : !torch.vtensor<[3],i1>
}
// CHECK-LABEL: func.func @torch.aten.squeeze.dim$same_shape(
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[2,1],si64>) -> !torch.vtensor<[2,1],si64> {
// CHECK-NEXT: return %[[ARG]]
func.func @torch.aten.squeeze.dim$same_shape(%arg0: !torch.vtensor<[2,1],si64>) -> !torch.vtensor<[2,1],si64> {
%int0 = torch.constant.int 0
%0 = torch.aten.squeeze.dim %arg0, %int0 : !torch.vtensor<[2,1],si64>, !torch.int -> !torch.vtensor<[2,1],si64>
return %0 : !torch.vtensor<[2,1],si64>
}
// CHECK-LABEL: func.func @torch.aten.squeeze.dim$not_fold
// CHECK: torch.aten.squeeze.dim
func.func @torch.aten.squeeze.dim$not_fold(%arg0: !torch.vtensor<[2,1],si64>) -> !torch.vtensor<[2],si64> {
%int1 = torch.constant.int 1
%0 = torch.aten.squeeze.dim %arg0, %int1 : !torch.vtensor<[2,1],si64>, !torch.int -> !torch.vtensor<[2],si64>
return %0 : !torch.vtensor<[2],si64>
}
// CHECK-LABEL: func.func @torch.aten.tensor$one_elem(