mirror of https://github.com/llvm/torch-mlir
[Torch] enhance fold of aten.squeeze.dim (#3558)
parent
d1e172f418
commit
aad1604046
|
@ -128,6 +128,17 @@ static FloatAttr getF64FloatAttr(MLIRContext *context, double value) {
|
|||
return FloatAttr::get(Float64Type::get(context), value);
|
||||
}
|
||||
|
||||
static DenseElementsAttr reshapeDenseElementsAttr(DenseElementsAttr attr,
|
||||
ShapedType newType) {
|
||||
// TODO: DenseElementsAttr::reshape is broken for bool splats.
|
||||
// Once that ticket is fixed, we can remove this conditional.
|
||||
if (attr.isSplat() && newType.getElementType().isInteger(/*width=*/1)) {
|
||||
auto splatValue = attr.getValues<bool>()[0];
|
||||
return DenseElementsAttr::get(newType, {splatValue});
|
||||
}
|
||||
return attr.reshape(newType);
|
||||
}
|
||||
|
||||
static Value getScalarIntValue(Value input, Location loc,
|
||||
PatternRewriter &rewriter) {
|
||||
auto inputType = input.getType();
|
||||
|
@ -798,11 +809,22 @@ OpFoldResult AtenSqueezeOp::fold(FoldAdaptor adaptor) {
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult AtenSqueezeDimOp::fold(FoldAdaptor adaptor) {
|
||||
if (getOperand(0).getType() != getResult().getType())
|
||||
auto inType = dyn_cast<ValueTensorType>(getOperand(0).getType());
|
||||
auto outType = dyn_cast<ValueTensorType>(getResult().getType());
|
||||
if (!inType || !outType || !inType.areAllSizesKnown() ||
|
||||
!outType.areAllSizesKnown() || !inType.hasDtype() ||
|
||||
!outType.hasDtype()) {
|
||||
return nullptr;
|
||||
if (auto tensorType = dyn_cast<BaseTensorType>(getOperand(0).getType())) {
|
||||
if (tensorType.hasSizes() && tensorType.getSizes().size() == 0)
|
||||
return getOperand(0);
|
||||
}
|
||||
|
||||
if (inType == outType) {
|
||||
return getOperand(0);
|
||||
}
|
||||
|
||||
DenseElementsAttr input =
|
||||
dyn_cast_or_null<DenseElementsAttr>(adaptor.getSelf());
|
||||
if (input) {
|
||||
return reshapeDenseElementsAttr(input, outType.toBuiltinTensor());
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
|
|
@ -379,15 +379,25 @@ func.func @torch.aten.view$to_rank0(%arg0: !torch.vtensor<[1],f32>) -> !torch.vt
|
|||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @torch.aten.squeeze.dim$0$static(
|
||||
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[2,1,2,1,2],f32>) -> !torch.vtensor<[2,1,2,1,2],f32> {
|
||||
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[2,1,2,1,2],f32>) -> !torch.vtensor<[2,2,1,2],f32> {
|
||||
// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[2,1,2,1,2],f32> -> tensor<2x1x2x1x2xf32>
|
||||
// CHECK: %[[INT0:.*]] = torch.constant.int 0
|
||||
// CHECK: %[[T1:.*]] = torch_c.from_builtin_tensor %[[T0]] : tensor<2x1x2x1x2xf32> -> !torch.vtensor<[2,1,2,1,2],f32>
|
||||
// CHECK: return %[[T1]] : !torch.vtensor<[2,1,2,1,2],f32>
|
||||
func.func @torch.aten.squeeze.dim$0$static(%arg0: !torch.vtensor<[2,1,2,1,2],f32>) -> !torch.vtensor<[2,1,2,1,2],f32> {
|
||||
%int0 = torch.constant.int 0
|
||||
%0 = torch.aten.squeeze.dim %arg0, %int0 : !torch.vtensor<[2,1,2,1,2],f32>, !torch.int -> !torch.vtensor<[2,1,2,1,2],f32>
|
||||
return %0 : !torch.vtensor<[2,1,2,1,2],f32>
|
||||
// CHECK: %[[INT1:.*]] = torch.constant.int 1
|
||||
// CHECK: %[[C0:.*]] = arith.constant 0 : index
|
||||
// CHECK: %[[DIM:.*]] = tensor.dim %[[T0]], %[[C0]] : tensor<2x1x2x1x2xf32>
|
||||
// CHECK: %[[C2:.*]] = arith.constant 2 : index
|
||||
// CHECK: %[[DIM_0:.*]] = tensor.dim %[[T0]], %[[C2]] : tensor<2x1x2x1x2xf32>
|
||||
// CHECK: %[[C3:.*]] = arith.constant 3 : index
|
||||
// CHECK: %[[DIM_1:.*]] = tensor.dim %[[T0]], %[[C3:.*]] : tensor<2x1x2x1x2xf32>
|
||||
// CHECK: %[[C4:.*]] = arith.constant 4 : index
|
||||
// CHECK: %[[DIM_2:.*]] = tensor.dim %[[T0]], %[[C4]] : tensor<2x1x2x1x2xf32>
|
||||
// CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[DIM]], %[[DIM_0]], %[[DIM_1]], %[[DIM_2]] : tensor<4xindex>
|
||||
// CHECK: %[[T1:.*]] = stablehlo.dynamic_reshape %[[T0]], %[[FROM_ELEMENTS]] : (tensor<2x1x2x1x2xf32>, tensor<4xindex>) -> tensor<2x2x1x2xf32>
|
||||
// CHECK: %[[T2:.*]] = torch_c.from_builtin_tensor %[[T1]] : tensor<2x2x1x2xf32> -> !torch.vtensor<[2,2,1,2],f32>
|
||||
// CHECK: return %[[T2]] : !torch.vtensor<[2,2,1,2],f32>
|
||||
func.func @torch.aten.squeeze.dim$0$static(%arg0: !torch.vtensor<[2,1,2,1,2],f32>) -> !torch.vtensor<[2,2,1,2],f32> {
|
||||
%int1 = torch.constant.int 1
|
||||
%0 = torch.aten.squeeze.dim %arg0, %int1 : !torch.vtensor<[2,1,2,1,2],f32>, !torch.int -> !torch.vtensor<[2,2,1,2],f32>
|
||||
return %0 : !torch.vtensor<[2,2,1,2],f32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
|
|
@ -1507,20 +1507,67 @@ func.func @torch.aten.Float.Tensor(%arg0: !torch.float) -> !torch.float {
|
|||
}
|
||||
|
||||
// CHECK-LABEL: func.func @torch.aten.squeeze$zero_rank(
|
||||
// CHECK-SAME: %[[ARG:.*]]: !torch.tensor<[],f32>) -> !torch.tensor<[],f32> {
|
||||
// CHECK-NEXT: return %[[ARG]] : !torch.tensor<[],f32>
|
||||
func.func @torch.aten.squeeze$zero_rank(%arg0: !torch.tensor<[],f32>) -> !torch.tensor<[],f32> {
|
||||
%0 = torch.aten.squeeze %arg0 : !torch.tensor<[],f32> -> !torch.tensor<[],f32>
|
||||
return %0 : !torch.tensor<[],f32>
|
||||
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[],f32>) -> !torch.vtensor<[],f32> {
|
||||
// CHECK-NEXT: return %[[ARG]] : !torch.vtensor<[],f32>
|
||||
func.func @torch.aten.squeeze$zero_rank(%arg0: !torch.vtensor<[],f32>) -> !torch.vtensor<[],f32> {
|
||||
%0 = torch.aten.squeeze %arg0 : !torch.vtensor<[],f32> -> !torch.vtensor<[],f32>
|
||||
return %0 : !torch.vtensor<[],f32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func.func @torch.aten.squeeze.dim$zero_rank(
|
||||
// CHECK-SAME: %[[ARG:.*]]: !torch.tensor<[],f32>) -> !torch.tensor<[],f32> {
|
||||
// CHECK-NEXT: return %[[ARG]] : !torch.tensor<[],f32>
|
||||
func.func @torch.aten.squeeze.dim$zero_rank(%arg0: !torch.tensor<[],f32>) -> !torch.tensor<[],f32> {
|
||||
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[],f32>) -> !torch.vtensor<[],f32> {
|
||||
// CHECK-NEXT: return %[[ARG]] : !torch.vtensor<[],f32>
|
||||
func.func @torch.aten.squeeze.dim$zero_rank(%arg0: !torch.vtensor<[],f32>) -> !torch.vtensor<[],f32> {
|
||||
%int0 = torch.constant.int 0
|
||||
%0 = torch.aten.squeeze.dim %arg0, %int0 : !torch.tensor<[],f32>, !torch.int -> !torch.tensor<[],f32>
|
||||
return %0 : !torch.tensor<[],f32>
|
||||
%0 = torch.aten.squeeze.dim %arg0, %int0 : !torch.vtensor<[],f32>, !torch.int -> !torch.vtensor<[],f32>
|
||||
return %0 : !torch.vtensor<[],f32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func.func @torch.aten.squeeze.dim$cst() -> !torch.vtensor<[2],si64> {
|
||||
// CHECK-NEXT: %[[CST:.+]] = torch.vtensor.literal(dense<[127, 128]> : tensor<2xsi64>) : !torch.vtensor<[2],si64>
|
||||
// CHECK-NEXT: return %[[CST]]
|
||||
func.func @torch.aten.squeeze.dim$cst() -> !torch.vtensor<[2],si64> {
|
||||
%int1 = torch.constant.int 1
|
||||
%0 = torch.vtensor.literal(dense<[[127], [128]]> : tensor<2x1xsi64>) : !torch.vtensor<[2,1],si64>
|
||||
%1 = torch.aten.squeeze.dim %0, %int1 : !torch.vtensor<[2,1],si64>, !torch.int -> !torch.vtensor<[2],si64>
|
||||
return %1 : !torch.vtensor<[2],si64>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func.func @torch.aten.squeeze.dim$cst_i1() -> !torch.vtensor<[3],i1> {
|
||||
// CHECK-NEXT: %[[CST:.+]] = torch.vtensor.literal(dense<[true, false, true]> : tensor<3xi1>) : !torch.vtensor<[3],i1>
|
||||
// CHECK-NEXT: return %[[CST]]
|
||||
func.func @torch.aten.squeeze.dim$cst_i1() -> !torch.vtensor<[3],i1> {
|
||||
%int1 = torch.constant.int 1
|
||||
%0 = torch.vtensor.literal(dense<[[true], [false], [true]]> : tensor<3x1xi1>) : !torch.vtensor<[3,1],i1>
|
||||
%1 = torch.aten.squeeze.dim %0, %int1 : !torch.vtensor<[3,1],i1>, !torch.int -> !torch.vtensor<[3],i1>
|
||||
return %1 : !torch.vtensor<[3],i1>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func.func @torch.aten.squeeze.dim$cst_splat_i1() -> !torch.vtensor<[3],i1> {
|
||||
// CHECK-NEXT: %[[CST:.+]] = torch.vtensor.literal(dense<true> : tensor<3xi1>) : !torch.vtensor<[3],i1>
|
||||
// CHECK-NEXT: return %[[CST]]
|
||||
func.func @torch.aten.squeeze.dim$cst_splat_i1() -> !torch.vtensor<[3],i1> {
|
||||
%int1 = torch.constant.int 1
|
||||
%0 = torch.vtensor.literal(dense<true> : tensor<3x1xi1>) : !torch.vtensor<[3,1],i1>
|
||||
%1 = torch.aten.squeeze.dim %0, %int1 : !torch.vtensor<[3,1],i1>, !torch.int -> !torch.vtensor<[3],i1>
|
||||
return %1 : !torch.vtensor<[3],i1>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func.func @torch.aten.squeeze.dim$same_shape(
|
||||
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[2,1],si64>) -> !torch.vtensor<[2,1],si64> {
|
||||
// CHECK-NEXT: return %[[ARG]]
|
||||
func.func @torch.aten.squeeze.dim$same_shape(%arg0: !torch.vtensor<[2,1],si64>) -> !torch.vtensor<[2,1],si64> {
|
||||
%int0 = torch.constant.int 0
|
||||
%0 = torch.aten.squeeze.dim %arg0, %int0 : !torch.vtensor<[2,1],si64>, !torch.int -> !torch.vtensor<[2,1],si64>
|
||||
return %0 : !torch.vtensor<[2,1],si64>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func.func @torch.aten.squeeze.dim$not_fold
|
||||
// CHECK: torch.aten.squeeze.dim
|
||||
func.func @torch.aten.squeeze.dim$not_fold(%arg0: !torch.vtensor<[2,1],si64>) -> !torch.vtensor<[2],si64> {
|
||||
%int1 = torch.constant.int 1
|
||||
%0 = torch.aten.squeeze.dim %arg0, %int1 : !torch.vtensor<[2,1],si64>, !torch.int -> !torch.vtensor<[2],si64>
|
||||
return %0 : !torch.vtensor<[2],si64>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func.func @torch.aten.tensor$one_elem(
|
||||
|
|
Loading…
Reference in New Issue