torch.aten.squeeze.dim lowering with dynamic dims (#3749)

Address https://github.com/nod-ai/SHARK-ModelDev/issues/846

Assume the dynamic squeezed dim is 1.
pull/3776/head
jinchen 2024-10-08 10:37:31 -07:00 committed by GitHub
parent 614fcdd153
commit 58489faf7f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 28 additions and 4 deletions

View File

@ -1658,10 +1658,17 @@ public:
if (!isValidDim(dim, inputRank)) if (!isValidDim(dim, inputRank))
return rewriter.notifyMatchFailure(op, "dim is statically invalid"); return rewriter.notifyMatchFailure(op, "dim is statically invalid");
// TODO: Handle the case where the dim(th) dimension is dynamic. // assert dynamic squeeze dim size == 1
if (inputType.isDynamicDim(dim)) { if (inputType.isDynamicDim(dim)) {
return rewriter.notifyMatchFailure( Value cstDim = rewriter.create<arith::ConstantIndexOp>(op.getLoc(), dim);
op, "unimplemented: dim(th) dimension is not expected to be dynamic"); Value dimVal = rewriter.create<tensor::DimOp>(op.getLoc(), input, cstDim);
Value cstOne = rewriter.create<arith::ConstantIndexOp>(op.getLoc(), 1);
Value cmp = rewriter.create<arith::CmpIOp>(
op.getLoc(), arith::CmpIPredicate::eq, dimVal, cstOne);
rewriter.create<cf::AssertOp>(
op.getLoc(), cmp,
rewriter.getStringAttr(
"Expected dynamic squeeze dim size to be statically 1"));
} }
const TypeConverter *typeConverter = getTypeConverter(); const TypeConverter *typeConverter = getTypeConverter();
@ -1671,7 +1678,7 @@ public:
// If the dim(th) dimension of operand tensor type is not statically unit, // If the dim(th) dimension of operand tensor type is not statically unit,
// `aten.squeeze` will behave as an identity operation. // `aten.squeeze` will behave as an identity operation.
if (inputType.getDimSize(dim) != 1) { if (inputType.getDimSize(dim) != 1 && !inputType.isDynamicDim(dim)) {
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, input); rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, input);
return success(); return success();
} }

View File

@ -0,0 +1,17 @@
// RUN: torch-mlir-opt <%s -convert-torch-to-linalg -split-input-file -verify-diagnostics | FileCheck %s
// CHECK-LABEL: func @torch.aten.squeeze.dim$dynamic
func.func @torch.aten.squeeze.dim$dynamic(%arg0: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[?,?],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 21 : si64, torch.onnx_meta.producer_name = "tf2onnx", torch.onnx_meta.producer_version = "1.5.2"} {
// CHECK: %[[BUILTIN_TENSOR:.*]] = torch_c.to_builtin_tensor %arg0 : !torch.vtensor<[?,?,?],f32> -> tensor<?x?x?xf32>
// CHECK: %[[C0:.*]] = torch.constant.int 0
// CHECK: %[[C0_1:.*]] = arith.constant 0 : index
// CHECK: %[[DIM:.*]] = tensor.dim %[[BUILTIN_TENSOR]], %[[C0_1]] : tensor<?x?x?xf32>
// CHECK: %[[C1:.*]] = arith.constant 1 : index
// CHECK: %[[CMPI:.*]] = arith.cmpi eq, %[[DIM]], %[[C1]] : index
// CHECK: cf.assert %[[CMPI]], "Expected dynamic squeeze dim size to be statically 1"
// CHECK: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[BUILTIN_TENSOR]] {{\[\[}}0, 1], [2]] : tensor<?x?x?xf32> into tensor<?x?xf32>
// CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[COLLAPSED]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
%int0 = torch.constant.int 0
%1 = torch.aten.squeeze.dim %arg0, %int0 : !torch.vtensor<[?,?,?],f32>, !torch.int -> !torch.vtensor<[?,?],f32>
return %1 : !torch.vtensor<[?,?],f32>
}