From aad16040463a6699b634756d94232ea1502d85e6 Mon Sep 17 00:00:00 2001
From: Yuanqiang Liu
Date: Wed, 24 Jul 2024 14:13:48 +0800
Subject: [PATCH] [Torch] enhance fold of aten.squeeze.dim (#3558)

---
 lib/Dialect/Torch/IR/TorchOps.cpp     | 30 +++++++--
 .../TorchToStablehlo/view_like.mlir   | 26 ++++---
 test/Dialect/Torch/canonicalize.mlir  | 67 ++++++++++++++++---
 3 files changed, 101 insertions(+), 22 deletions(-)

diff --git a/lib/Dialect/Torch/IR/TorchOps.cpp b/lib/Dialect/Torch/IR/TorchOps.cpp
index 0b5617440..66a027909 100644
--- a/lib/Dialect/Torch/IR/TorchOps.cpp
+++ b/lib/Dialect/Torch/IR/TorchOps.cpp
@@ -128,6 +128,17 @@ static FloatAttr getF64FloatAttr(MLIRContext *context, double value) {
   return FloatAttr::get(Float64Type::get(context), value);
 }
 
+static DenseElementsAttr reshapeDenseElementsAttr(DenseElementsAttr attr,
+                                                  ShapedType newType) {
+  // TODO: DenseElementsAttr::reshape is broken for bool splats.
+  // Once that ticket is fixed, we can remove this conditional.
+  if (attr.isSplat() && newType.getElementType().isInteger(/*width=*/1)) {
+    auto splatValue = attr.getValues<bool>()[0];
+    return DenseElementsAttr::get(newType, {splatValue});
+  }
+  return attr.reshape(newType);
+}
+
 static Value getScalarIntValue(Value input, Location loc,
                                PatternRewriter &rewriter) {
   auto inputType = input.getType();
@@ -798,11 +809,22 @@ OpFoldResult AtenSqueezeOp::fold(FoldAdaptor adaptor) {
 //===----------------------------------------------------------------------===//
 
 OpFoldResult AtenSqueezeDimOp::fold(FoldAdaptor adaptor) {
-  if (getOperand(0).getType() != getResult().getType())
+  auto inType = dyn_cast<ValueTensorType>(getOperand(0).getType());
+  auto outType = dyn_cast<ValueTensorType>(getResult().getType());
+  if (!inType || !outType || !inType.areAllSizesKnown() ||
+      !outType.areAllSizesKnown() || !inType.hasDtype() ||
+      !outType.hasDtype()) {
     return nullptr;
-  if (auto tensorType = dyn_cast<BaseTensorType>(getOperand(0).getType())) {
-    if (tensorType.hasSizes() && tensorType.getSizes().size() == 0)
-      return getOperand(0);
+  }
+
+  if (inType == outType) {
+    return getOperand(0);
+  }
+
+  DenseElementsAttr input =
+      dyn_cast_or_null<DenseElementsAttr>(adaptor.getSelf());
+  if (input) {
+    return reshapeDenseElementsAttr(input, outType.toBuiltinTensor());
   }
   return nullptr;
 }
diff --git a/test/Conversion/TorchToStablehlo/view_like.mlir b/test/Conversion/TorchToStablehlo/view_like.mlir
index 2de800804..5e08f2d16 100644
--- a/test/Conversion/TorchToStablehlo/view_like.mlir
+++ b/test/Conversion/TorchToStablehlo/view_like.mlir
@@ -379,15 +379,25 @@ func.func @torch.aten.view$to_rank0(%arg0: !torch.vtensor<[1],f32>) -> !torch.vt
 // -----
 
 // CHECK-LABEL: func.func @torch.aten.squeeze.dim$0$static(
-// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[2,1,2,1,2],f32>) -> !torch.vtensor<[2,1,2,1,2],f32> {
+// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[2,1,2,1,2],f32>) -> !torch.vtensor<[2,2,1,2],f32> {
 // CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[2,1,2,1,2],f32> -> tensor<2x1x2x1x2xf32>
-// CHECK: %[[INT0:.*]] = torch.constant.int 0
-// CHECK: %[[T1:.*]] = torch_c.from_builtin_tensor %[[T0]] : tensor<2x1x2x1x2xf32> -> !torch.vtensor<[2,1,2,1,2],f32>
-// CHECK: return %[[T1]] : !torch.vtensor<[2,1,2,1,2],f32>
-func.func @torch.aten.squeeze.dim$0$static(%arg0: !torch.vtensor<[2,1,2,1,2],f32>) -> !torch.vtensor<[2,1,2,1,2],f32> {
-  %int0 = torch.constant.int 0
-  %0 = torch.aten.squeeze.dim %arg0, %int0 : !torch.vtensor<[2,1,2,1,2],f32>, !torch.int -> !torch.vtensor<[2,1,2,1,2],f32>
-  return %0 : !torch.vtensor<[2,1,2,1,2],f32>
+// CHECK: %[[INT1:.*]] = torch.constant.int 1
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[DIM:.*]] = tensor.dim %[[T0]], %[[C0]] : tensor<2x1x2x1x2xf32>
+// CHECK: %[[C2:.*]] = arith.constant 2 : index
+// CHECK: %[[DIM_0:.*]] = tensor.dim %[[T0]], %[[C2]] : tensor<2x1x2x1x2xf32>
+// CHECK: %[[C3:.*]] = arith.constant 3 : index
+// CHECK: %[[DIM_1:.*]] = tensor.dim %[[T0]], %[[C3]] : tensor<2x1x2x1x2xf32>
+// CHECK: %[[C4:.*]] = arith.constant 4 : index
+// CHECK: %[[DIM_2:.*]] = tensor.dim %[[T0]], %[[C4]] : tensor<2x1x2x1x2xf32>
+// CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[DIM]], %[[DIM_0]], %[[DIM_1]], %[[DIM_2]] : tensor<4xindex>
+// CHECK: %[[T1:.*]] = stablehlo.dynamic_reshape %[[T0]], %[[FROM_ELEMENTS]] : (tensor<2x1x2x1x2xf32>, tensor<4xindex>) -> tensor<2x2x1x2xf32>
+// CHECK: %[[T2:.*]] = torch_c.from_builtin_tensor %[[T1]] : tensor<2x2x1x2xf32> -> !torch.vtensor<[2,2,1,2],f32>
+// CHECK: return %[[T2]] : !torch.vtensor<[2,2,1,2],f32>
+func.func @torch.aten.squeeze.dim$0$static(%arg0: !torch.vtensor<[2,1,2,1,2],f32>) -> !torch.vtensor<[2,2,1,2],f32> {
+  %int1 = torch.constant.int 1
+  %0 = torch.aten.squeeze.dim %arg0, %int1 : !torch.vtensor<[2,1,2,1,2],f32>, !torch.int -> !torch.vtensor<[2,2,1,2],f32>
+  return %0 : !torch.vtensor<[2,2,1,2],f32>
 }
 
 // -----
diff --git a/test/Dialect/Torch/canonicalize.mlir b/test/Dialect/Torch/canonicalize.mlir
index f0b8ff3e8..0bb4455f2 100644
--- a/test/Dialect/Torch/canonicalize.mlir
+++ b/test/Dialect/Torch/canonicalize.mlir
@@ -1507,20 +1507,67 @@ func.func @torch.aten.Float.Tensor(%arg0: !torch.float) -> !torch.float {
 }
 
 // CHECK-LABEL: func.func @torch.aten.squeeze$zero_rank(
-// CHECK-SAME: %[[ARG:.*]]: !torch.tensor<[],f32>) -> !torch.tensor<[],f32> {
-// CHECK-NEXT: return %[[ARG]] : !torch.tensor<[],f32>
-func.func @torch.aten.squeeze$zero_rank(%arg0: !torch.tensor<[],f32>) -> !torch.tensor<[],f32> {
-  %0 = torch.aten.squeeze %arg0 : !torch.tensor<[],f32> -> !torch.tensor<[],f32>
-  return %0 : !torch.tensor<[],f32>
+// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[],f32>) -> !torch.vtensor<[],f32> {
+// CHECK-NEXT: return %[[ARG]] : !torch.vtensor<[],f32>
+func.func @torch.aten.squeeze$zero_rank(%arg0: !torch.vtensor<[],f32>) -> !torch.vtensor<[],f32> {
+  %0 = torch.aten.squeeze %arg0 : !torch.vtensor<[],f32> -> !torch.vtensor<[],f32>
+  return %0 : !torch.vtensor<[],f32>
 }
 
 // CHECK-LABEL: func.func @torch.aten.squeeze.dim$zero_rank(
-// CHECK-SAME: %[[ARG:.*]]: !torch.tensor<[],f32>) -> !torch.tensor<[],f32> {
-// CHECK-NEXT: return %[[ARG]] : !torch.tensor<[],f32>
-func.func @torch.aten.squeeze.dim$zero_rank(%arg0: !torch.tensor<[],f32>) -> !torch.tensor<[],f32> {
+// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[],f32>) -> !torch.vtensor<[],f32> {
+// CHECK-NEXT: return %[[ARG]] : !torch.vtensor<[],f32>
+func.func @torch.aten.squeeze.dim$zero_rank(%arg0: !torch.vtensor<[],f32>) -> !torch.vtensor<[],f32> {
   %int0 = torch.constant.int 0
-  %0 = torch.aten.squeeze.dim %arg0, %int0 : !torch.tensor<[],f32>, !torch.int -> !torch.tensor<[],f32>
-  return %0 : !torch.tensor<[],f32>
+  %0 = torch.aten.squeeze.dim %arg0, %int0 : !torch.vtensor<[],f32>, !torch.int -> !torch.vtensor<[],f32>
+  return %0 : !torch.vtensor<[],f32>
+}
+
+// CHECK-LABEL: func.func @torch.aten.squeeze.dim$cst() -> !torch.vtensor<[2],si64> {
+// CHECK-NEXT: %[[CST:.+]] = torch.vtensor.literal(dense<[127, 128]> : tensor<2xsi64>) : !torch.vtensor<[2],si64>
+// CHECK-NEXT: return %[[CST]]
+func.func @torch.aten.squeeze.dim$cst() -> !torch.vtensor<[2],si64> {
+  %int1 = torch.constant.int 1
+  %0 = torch.vtensor.literal(dense<[[127], [128]]> : tensor<2x1xsi64>) : !torch.vtensor<[2,1],si64>
+  %1 = torch.aten.squeeze.dim %0, %int1 : !torch.vtensor<[2,1],si64>, !torch.int -> !torch.vtensor<[2],si64>
+  return %1 : !torch.vtensor<[2],si64>
+}
+
+// CHECK-LABEL: func.func @torch.aten.squeeze.dim$cst_i1() -> !torch.vtensor<[3],i1> {
+// CHECK-NEXT: %[[CST:.+]] = torch.vtensor.literal(dense<[true, false, true]> : tensor<3xi1>) : !torch.vtensor<[3],i1>
+// CHECK-NEXT: return %[[CST]]
+func.func @torch.aten.squeeze.dim$cst_i1() -> !torch.vtensor<[3],i1> {
+  %int1 = torch.constant.int 1
+  %0 = torch.vtensor.literal(dense<[[true], [false], [true]]> : tensor<3x1xi1>) : !torch.vtensor<[3,1],i1>
+  %1 = torch.aten.squeeze.dim %0, %int1 : !torch.vtensor<[3,1],i1>, !torch.int -> !torch.vtensor<[3],i1>
+  return %1 : !torch.vtensor<[3],i1>
+}
+
+// CHECK-LABEL: func.func @torch.aten.squeeze.dim$cst_splat_i1() -> !torch.vtensor<[3],i1> {
+// CHECK-NEXT: %[[CST:.+]] = torch.vtensor.literal(dense<true> : tensor<3xi1>) : !torch.vtensor<[3],i1>
+// CHECK-NEXT: return %[[CST]]
+func.func @torch.aten.squeeze.dim$cst_splat_i1() -> !torch.vtensor<[3],i1> {
+  %int1 = torch.constant.int 1
+  %0 = torch.vtensor.literal(dense<true> : tensor<3x1xi1>) : !torch.vtensor<[3,1],i1>
+  %1 = torch.aten.squeeze.dim %0, %int1 : !torch.vtensor<[3,1],i1>, !torch.int -> !torch.vtensor<[3],i1>
+  return %1 : !torch.vtensor<[3],i1>
+}
+
+// CHECK-LABEL: func.func @torch.aten.squeeze.dim$same_shape(
+// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[2,1],si64>) -> !torch.vtensor<[2,1],si64> {
+// CHECK-NEXT: return %[[ARG]]
+func.func @torch.aten.squeeze.dim$same_shape(%arg0: !torch.vtensor<[2,1],si64>) -> !torch.vtensor<[2,1],si64> {
+  %int0 = torch.constant.int 0
+  %0 = torch.aten.squeeze.dim %arg0, %int0 : !torch.vtensor<[2,1],si64>, !torch.int -> !torch.vtensor<[2,1],si64>
+  return %0 : !torch.vtensor<[2,1],si64>
+}
+
+// CHECK-LABEL: func.func @torch.aten.squeeze.dim$not_fold
+// CHECK: torch.aten.squeeze.dim
+func.func @torch.aten.squeeze.dim$not_fold(%arg0: !torch.vtensor<[2,1],si64>) -> !torch.vtensor<[2],si64> {
+  %int1 = torch.constant.int 1
+  %0 = torch.aten.squeeze.dim %arg0, %int1 : !torch.vtensor<[2,1],si64>, !torch.int -> !torch.vtensor<[2],si64>
+  return %0 : !torch.vtensor<[2],si64>
 }
 
 // CHECK-LABEL: func.func @torch.aten.tensor$one_elem(