diff --git a/e2e_testing/xfail_sets.py b/e2e_testing/xfail_sets.py
index eaa981e94..6004f6b4a 100644
--- a/e2e_testing/xfail_sets.py
+++ b/e2e_testing/xfail_sets.py
@@ -455,6 +455,7 @@ TOSA_PASS_SET = {
     "ArgmaxModule_with_dim",
     "_LogSoftmaxModuleStable_basic",
     "ElementwiseAtenWhereSelfModule_basic",
+    "ElementwiseUnsqueezeBroadcastModule_basic",
     "LiftFreshCopyModule_basic",
     "ReduceSumDimIntListKeepDimNegativeDimStaticModule_basic",
     "ReduceSumDimIntListFloatModule_basic",
diff --git a/lib/Conversion/TorchToTosa/TorchToTosa.cpp b/lib/Conversion/TorchToTosa/TorchToTosa.cpp
index 8a01c77bb..c9b379e10 100644
--- a/lib/Conversion/TorchToTosa/TorchToTosa.cpp
+++ b/lib/Conversion/TorchToTosa/TorchToTosa.cpp
@@ -2435,7 +2435,7 @@ LogicalResult ConvertAtenOp<AtenUnsqueezeOp>::matchAndRewrite(
     return rewriter.notifyMatchFailure(op, "dim must be a Scalar constant");
 
   dim = toPositiveDim(dim, selfRank);
-  if (!isValidDim(dim, selfRank))
+  if (!isValidDim(dim, selfRank + 1))
    return rewriter.notifyMatchFailure(op, "dim is statically invalid");
 
   SmallVector<int64_t> outShape;
diff --git a/test/Conversion/TorchToTosa/basic.mlir b/test/Conversion/TorchToTosa/basic.mlir
index cfa12f522..f788deefb 100644
--- a/test/Conversion/TorchToTosa/basic.mlir
+++ b/test/Conversion/TorchToTosa/basic.mlir
@@ -690,18 +690,18 @@ func.func @torch.aten.zeros$basic() -> !torch.vtensor<[3,4],f32> {
 
 // -----
 
 // CHECK-LABEL: func.func @torch.aten.unsqueeze$basic(
-// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[4,3],si32>) -> !torch.vtensor<[4,1,3],si32> {
+// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[4,3],si32>) -> !torch.vtensor<[4,3,1],si32> {
 // CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[4,3],si32> -> tensor<4x3xi32>
-// CHECK: %[[VAL_2:.*]] = torch.constant.int 1
-// CHECK: %[[VAL_3:.*]] = "tosa.reshape"(%[[VAL_1]]) {new_shape = [4, 1, 3]} : (tensor<4x3xi32>) -> tensor<4x1x3xi32>
-// CHECK: %[[VAL_4:.*]] = torch_c.from_builtin_tensor %[[VAL_3]] : tensor<4x1x3xi32> -> !torch.vtensor<[4,1,3],si32>
-// CHECK: return %[[VAL_4]] : !torch.vtensor<[4,1,3],si32>
+// CHECK: %[[VAL_2:.*]] = torch.constant.int 2
+// CHECK: %[[VAL_3:.*]] = "tosa.reshape"(%[[VAL_1]]) {new_shape = [4, 3, 1]} : (tensor<4x3xi32>) -> tensor<4x3x1xi32>
+// CHECK: %[[VAL_4:.*]] = torch_c.from_builtin_tensor %[[VAL_3]] : tensor<4x3x1xi32> -> !torch.vtensor<[4,3,1],si32>
+// CHECK: return %[[VAL_4]] : !torch.vtensor<[4,3,1],si32>
 // CHECK: }
-func.func @torch.aten.unsqueeze$basic(%arg0: !torch.vtensor<[4,3],si32> ) -> !torch.vtensor<[4,1,3],si32> {
-  %int1 = torch.constant.int 1
-  %0 = torch.aten.unsqueeze %arg0, %int1 : !torch.vtensor<[4,3],si32>, !torch.int -> !torch.vtensor<[4,1,3],si32>
-  return %0 : !torch.vtensor<[4,1,3],si32>
+func.func @torch.aten.unsqueeze$basic(%arg0: !torch.vtensor<[4,3],si32> ) -> !torch.vtensor<[4,3,1],si32> {
+  %int2 = torch.constant.int 2
+  %0 = torch.aten.unsqueeze %arg0, %int2 : !torch.vtensor<[4,3],si32>, !torch.int -> !torch.vtensor<[4,3,1],si32>
+  return %0 : !torch.vtensor<[4,3,1],si32>
 }
 
 // -----
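
Background on the one-line fix: aten.unsqueeze inserts a new dimension, so for a rank-N input the valid insertion positions are 0 through N inclusive (N+1 choices, or [-(N+1), N] with negative indexing). The old check isValidDim(dim, selfRank) rejected the legal boundary case dim == selfRank, e.g. unsqueezing a rank-2 tensor at dim 2, which the updated test above now exercises. The following is a minimal Python sketch of PyTorch's unsqueeze semantics, illustrative only and not part of the patch:

import torch

# unsqueeze on a rank-N tensor accepts dim in [-(N+1), N]: the new
# unit dimension may be inserted after the last existing one.
x = torch.randint(0, 10, (4, 3), dtype=torch.int32)  # rank 2

# dim == selfRank (here 2) is valid and appends a trailing unit
# dimension; this is exactly the case the old
# isValidDim(dim, selfRank) check wrongly rejected.
y = x.unsqueeze(2)
assert y.shape == (4, 3, 1)

# Negative dims normalize the same way: -1 maps to rank 2,
# so the result is also (4, 3, 1).
assert x.unsqueeze(-1).shape == (4, 3, 1)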