diff --git a/include/torch-mlir/Conversion/TorchOnnxToTorch/Utils.h b/include/torch-mlir/Conversion/TorchOnnxToTorch/Utils.h index 181a13fb8..74c2aedcd 100644 --- a/include/torch-mlir/Conversion/TorchOnnxToTorch/Utils.h +++ b/include/torch-mlir/Conversion/TorchOnnxToTorch/Utils.h @@ -90,6 +90,12 @@ struct onnx_list_of_constant_ints_op_binder { } return true; } + if (ElementsAttr attr = dyn_cast_or_null( + constOp->getAttr("torch.onnx.value"))) { + for (auto axis : attr.getValues()) + bind_values.push_back(axis.getSExtValue()); + return true; + } return false; } }; diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir index 4ef44c968..ce80527dc 100644 --- a/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir +++ b/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir @@ -505,6 +505,18 @@ func.func @test_squeeze_two_axes(%arg0: !torch.vtensor<[3,1,4,5,1],f32>, %arg1: // ----- +// CHECK-LABEL: func.func @test_unsqueeze_dyn_dims +func.func @test_unsqueeze_dyn_dims(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,1,?],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 21 : si64, torch.onnx_meta.producer_name = "pytorch", torch.onnx_meta.producer_version = "1.11.0"} { + // CHECK: %[[x0:.*]] = torch.vtensor.literal(dense<1> : tensor<1xsi64>) : !torch.vtensor<[1],si64> + // CHECK: %[[int1:.*]] = torch.constant.int 1 + // CHECK: %[[x1:.*]] = torch.aten.unsqueeze %arg0, %[[int1]] : !torch.vtensor<[?,?],f32>, !torch.int -> !torch.vtensor<[?,1,?],f32> + %0 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<1> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64> + %1 = torch.operator "onnx.Unsqueeze"(%arg0, %0) : (!torch.vtensor<[?,?],f32>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[?,1,?],f32> + return %1 : !torch.vtensor<[?,1,?],f32> +} + +// ----- + // CHECK-LABEL: func.func @test_unsqueeze_axis_0 func.func @test_unsqueeze_axis_0(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[1,3,4,5],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: %[[INT0:.*]] = torch.constant.int 0