mirror of https://github.com/llvm/torch-mlir
torch.aten.squeeze.dim lowering with dynamic dims (#3749)
Address https://github.com/nod-ai/SHARK-ModelDev/issues/846 Assume the dynamic squeezed dim is 1.pull/3776/head
parent
614fcdd153
commit
58489faf7f
|
@ -1658,10 +1658,17 @@ public:
|
||||||
if (!isValidDim(dim, inputRank))
|
if (!isValidDim(dim, inputRank))
|
||||||
return rewriter.notifyMatchFailure(op, "dim is statically invalid");
|
return rewriter.notifyMatchFailure(op, "dim is statically invalid");
|
||||||
|
|
||||||
// TODO: Handle the case where the dim(th) dimension is dynamic.
|
// assert dynamic squeeze dim size == 1
|
||||||
if (inputType.isDynamicDim(dim)) {
|
if (inputType.isDynamicDim(dim)) {
|
||||||
return rewriter.notifyMatchFailure(
|
Value cstDim = rewriter.create<arith::ConstantIndexOp>(op.getLoc(), dim);
|
||||||
op, "unimplemented: dim(th) dimension is not expected to be dynamic");
|
Value dimVal = rewriter.create<tensor::DimOp>(op.getLoc(), input, cstDim);
|
||||||
|
Value cstOne = rewriter.create<arith::ConstantIndexOp>(op.getLoc(), 1);
|
||||||
|
Value cmp = rewriter.create<arith::CmpIOp>(
|
||||||
|
op.getLoc(), arith::CmpIPredicate::eq, dimVal, cstOne);
|
||||||
|
rewriter.create<cf::AssertOp>(
|
||||||
|
op.getLoc(), cmp,
|
||||||
|
rewriter.getStringAttr(
|
||||||
|
"Expected dynamic squeeze dim size to be statically 1"));
|
||||||
}
|
}
|
||||||
|
|
||||||
const TypeConverter *typeConverter = getTypeConverter();
|
const TypeConverter *typeConverter = getTypeConverter();
|
||||||
|
@ -1671,7 +1678,7 @@ public:
|
||||||
|
|
||||||
// If the dim(th) dimension of operand tensor type is not statically unit,
|
// If the dim(th) dimension of operand tensor type is not statically unit,
|
||||||
// `aten.squeeze` will behave as an identity operation.
|
// `aten.squeeze` will behave as an identity operation.
|
||||||
if (inputType.getDimSize(dim) != 1) {
|
if (inputType.getDimSize(dim) != 1 && !inputType.isDynamicDim(dim)) {
|
||||||
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, input);
|
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, input);
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,17 @@
|
||||||
|
// RUN: torch-mlir-opt <%s -convert-torch-to-linalg -split-input-file -verify-diagnostics | FileCheck %s
|
||||||
|
|
||||||
|
// CHECK-LABEL: func @torch.aten.squeeze.dim$dynamic
|
||||||
|
func.func @torch.aten.squeeze.dim$dynamic(%arg0: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[?,?],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 21 : si64, torch.onnx_meta.producer_name = "tf2onnx", torch.onnx_meta.producer_version = "1.5.2"} {
|
||||||
|
// CHECK: %[[BUILTIN_TENSOR:.*]] = torch_c.to_builtin_tensor %arg0 : !torch.vtensor<[?,?,?],f32> -> tensor<?x?x?xf32>
|
||||||
|
// CHECK: %[[C0:.*]] = torch.constant.int 0
|
||||||
|
// CHECK: %[[C0_1:.*]] = arith.constant 0 : index
|
||||||
|
// CHECK: %[[DIM:.*]] = tensor.dim %[[BUILTIN_TENSOR]], %[[C0_1]] : tensor<?x?x?xf32>
|
||||||
|
// CHECK: %[[C1:.*]] = arith.constant 1 : index
|
||||||
|
// CHECK: %[[CMPI:.*]] = arith.cmpi eq, %[[DIM]], %[[C1]] : index
|
||||||
|
// CHECK: cf.assert %[[CMPI]], "Expected dynamic squeeze dim size to be statically 1"
|
||||||
|
// CHECK: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[BUILTIN_TENSOR]] {{\[\[}}0, 1], [2]] : tensor<?x?x?xf32> into tensor<?x?xf32>
|
||||||
|
// CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[COLLAPSED]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
|
||||||
|
%int0 = torch.constant.int 0
|
||||||
|
%1 = torch.aten.squeeze.dim %arg0, %int0 : !torch.vtensor<[?,?,?],f32>, !torch.int -> !torch.vtensor<[?,?],f32>
|
||||||
|
return %1 : !torch.vtensor<[?,?],f32>
|
||||||
|
}
|
Loading…
Reference in New Issue