mirror of https://github.com/llvm/torch-mlir
torch.aten.squeeze.dim lowering with dynamic dims (#3749)
Addresses https://github.com/nod-ai/SHARK-ModelDev/issues/846. Assumes the dynamic squeezed dim is 1. (branch: pull/3776/head)
parent
614fcdd153
commit
58489faf7f
|
@ -1658,10 +1658,17 @@ public:
|
|||
if (!isValidDim(dim, inputRank))
|
||||
return rewriter.notifyMatchFailure(op, "dim is statically invalid");
|
||||
|
||||
// TODO: Handle the case where the dim(th) dimension is dynamic.
|
||||
// assert dynamic squeeze dim size == 1
|
||||
if (inputType.isDynamicDim(dim)) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "unimplemented: dim(th) dimension is not expected to be dynamic");
|
||||
Value cstDim = rewriter.create<arith::ConstantIndexOp>(op.getLoc(), dim);
|
||||
Value dimVal = rewriter.create<tensor::DimOp>(op.getLoc(), input, cstDim);
|
||||
Value cstOne = rewriter.create<arith::ConstantIndexOp>(op.getLoc(), 1);
|
||||
Value cmp = rewriter.create<arith::CmpIOp>(
|
||||
op.getLoc(), arith::CmpIPredicate::eq, dimVal, cstOne);
|
||||
rewriter.create<cf::AssertOp>(
|
||||
op.getLoc(), cmp,
|
||||
rewriter.getStringAttr(
|
||||
"Expected dynamic squeeze dim size to be statically 1"));
|
||||
}
|
||||
|
||||
const TypeConverter *typeConverter = getTypeConverter();
|
||||
|
@ -1671,7 +1678,7 @@ public:
|
|||
|
||||
// If the dim(th) dimension of operand tensor type is not statically unit,
|
||||
// `aten.squeeze` will behave as an identity operation.
|
||||
if (inputType.getDimSize(dim) != 1) {
|
||||
if (inputType.getDimSize(dim) != 1 && !inputType.isDynamicDim(dim)) {
|
||||
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, input);
|
||||
return success();
|
||||
}
|
||||
|
|
|
@ -0,0 +1,17 @@
|
|||
// RUN: torch-mlir-opt <%s -convert-torch-to-linalg -split-input-file -verify-diagnostics | FileCheck %s

// Lowering of torch.aten.squeeze.dim when the squeezed dimension is dynamic:
// the conversion emits a runtime cf.assert checking that the dynamic extent
// equals 1 before collapsing the leading unit dimension via
// tensor.collapse_shape.

// CHECK-LABEL: func @torch.aten.squeeze.dim$dynamic
func.func @torch.aten.squeeze.dim$dynamic(%arg0: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[?,?],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 21 : si64, torch.onnx_meta.producer_name = "tf2onnx", torch.onnx_meta.producer_version = "1.5.2"} {
  // CHECK: %[[IN:.*]] = torch_c.to_builtin_tensor %arg0 : !torch.vtensor<[?,?,?],f32> -> tensor<?x?x?xf32>
  // CHECK: %[[INT0:.*]] = torch.constant.int 0
  // CHECK: %[[IDX0:.*]] = arith.constant 0 : index
  // CHECK: %[[EXTENT:.*]] = tensor.dim %[[IN]], %[[IDX0]] : tensor<?x?x?xf32>
  // CHECK: %[[ONE:.*]] = arith.constant 1 : index
  // CHECK: %[[IS_UNIT:.*]] = arith.cmpi eq, %[[EXTENT]], %[[ONE]] : index
  // CHECK: cf.assert %[[IS_UNIT]], "Expected dynamic squeeze dim size to be statically 1"
  // CHECK: %[[SQUEEZED:.*]] = tensor.collapse_shape %[[IN]] {{\[\[}}0, 1], [2]] : tensor<?x?x?xf32> into tensor<?x?xf32>
  // CHECK: %[[OUT:.*]] = torch_c.from_builtin_tensor %[[SQUEEZED]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
  %int0 = torch.constant.int 0
  %1 = torch.aten.squeeze.dim %arg0, %int0 : !torch.vtensor<[?,?,?],f32>, !torch.int -> !torch.vtensor<[?,?],f32>
  return %1 : !torch.vtensor<[?,?],f32>
}
|
Loading…
Reference in New Issue