diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
index 2f276b1a2..6207e753e 100644
--- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
+++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
@@ -2593,16 +2593,22 @@ public:
     // first the input tensor is flattened to 1d tensor and then the reduction
     // happens on the 0th dimension.
     if (isa<Torch::NoneType>(dim.getType())) {
-      BaseTensorType flattenType =
-          cast<BaseTensorType>(inputType.getWithSizesAndDtype(
-              {kUnknownSize}, inputType.getOptionalDtype()));
-      Value zero =
-          rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(0));
-      Value end = rewriter.create<ConstantIntOp>(
-          loc, rewriter.getI64IntegerAttr(inputRank - 1));
+      Value zero = rewriter.create<ConstantIntOp>(loc, 0);
       Value falseValue = rewriter.create<ConstantBoolOp>(loc, false);
-      input = rewriter.create<AtenFlattenUsingIntsOp>(loc, flattenType, input,
-                                                      zero, end);
+      if (inputType.getSizes().size() > 1) {
+        int64_t flattenSize = Torch::kUnknownSize;
+        if (inputType.areAllSizesKnown()) {
+          flattenSize = 1;
+          for (int64_t sze : inputType.getSizes())
+            flattenSize *= sze;
+        }
+        auto flattenType = cast<BaseTensorType>(inputType.getWithSizesAndDtype(
+            {flattenSize}, inputType.getOptionalDtype()));
+        Value end = rewriter.create<ConstantIntOp>(
+            loc, rewriter.getI64IntegerAttr(inputRank - 1));
+        input = rewriter.create<AtenFlattenUsingIntsOp>(loc, flattenType, input,
+                                                        zero, end);
+      }
       Value resultIndices =
           rewriter
               .create<DecompOpTy>(
diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py
index e0011b9a3..e8bdda1e6 100644
--- a/projects/pt1/e2e_testing/xfail_sets.py
+++ b/projects/pt1/e2e_testing/xfail_sets.py
@@ -545,10 +545,6 @@ FX_IMPORTER_CRASHING_SET = LINALG_CRASHING_SET | {
 
 FX_IMPORTER_STABLEHLO_XFAIL_SET = {
     "AddFloatIntModule_basic",
-    "ArgmaxIntModule_basic",
-    "ArgmaxIntModule_multiple_maxs",
-    "ArgmaxKeepdimModule_basic",
-    "ArgmaxModule_basic",
     "AtenKthvalueDynamicDimsModule_basic",
     "AtenKthvalueFloat64DynamicDimsModule_basic",
     "AtenKthvalueFloat64Module_basic",
@@ -618,9 +614,6 @@ FX_IMPORTER_STABLEHLO_XFAIL_SET = {
     "AnyBoolFalseModule_basic",
     "AnyBoolTrueModule_basic",
     "ArangeStartOutViewModule_basic",
-    "ArgminIntModule_basic",
-    "ArgminIntModule_multiple_mins",
-    "ArgminModule_basic",
     "AtenComplexImagModule_basic",
     "AtenComplexRealModule_basic",
     "AtenComplexViewModule_basic",
diff --git a/test/Dialect/Torch/decompose-complex-ops.mlir b/test/Dialect/Torch/decompose-complex-ops.mlir
index 4da482af0..c29635de6 100644
--- a/test/Dialect/Torch/decompose-complex-ops.mlir
+++ b/test/Dialect/Torch/decompose-complex-ops.mlir
@@ -25,6 +25,19 @@ func.func @matmul_decompose_3d(%arg0: !torch.vtensor<[?,?,?],f32>, %arg1: !torch
   return %0 : !torch.tensor
 }
 
+// -----
+// CHECK-LABEL: func.func @argmax_rank_1
+// CHECK: %[[I0:.*]] = torch.constant.int 0
+// CHECK: %[[FALSE:.*]] = torch.constant.bool false
+// CHECK: %[[VALUES:.*]], %[[INDICES:.*]] = torch.aten.max.dim %arg0, %[[I0]], %[[FALSE]] : !torch.vtensor<[20],si32>, !torch.int, !torch.bool -> !torch.vtensor<[],si32>, !torch.vtensor<[],si64>
+// CHECK: return %[[INDICES]] : !torch.vtensor<[],si64>
+func.func @argmax_rank_1(%arg0: !torch.vtensor<[20],si32>) -> !torch.vtensor<[],si64> {
+  %none = torch.constant.none
+  %false = torch.constant.bool false
+  %7 = torch.aten.argmax %arg0, %none, %false : !torch.vtensor<[20],si32>, !torch.none, !torch.bool -> !torch.vtensor<[],si64>
+  return %7 : !torch.vtensor<[],si64>
+}
+
 // -----
 // CHECK-LABEL: func.func @torch.aten.type_as$basic(
 // CHECK-SAME: %[[ARG_0:.*]]: !torch.tensor, %[[ARG_1:.*]]: !torch.tensor) -> !torch.tensor {
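
For reference, and not part of the patch: with the new guard, only inputs of rank
greater than 1 are flattened before torch.aten.max.dim / torch.aten.min.dim, and
when all input sizes are statically known the flattened dimension folds to their
product rather than `?`. A hypothetical rank-2 counterpart to the @argmax_rank_1
test above (the @argmax_rank_2 name and the 3x4 shape are illustrative, not from
the patch's test suite):

// Rank-2 input still takes the flatten path: the decomposition flattens to
// !torch.vtensor<[12],si32> (12 = 3 * 4, folded since all sizes are known)
// and then reduces along dim 0, yielding a rank-0 index tensor.
func.func @argmax_rank_2(%arg0: !torch.vtensor<[3,4],si32>) -> !torch.vtensor<[],si64> {
  %none = torch.constant.none
  %false = torch.constant.bool false
  %0 = torch.aten.argmax %arg0, %none, %false : !torch.vtensor<[3,4],si32>, !torch.none, !torch.bool -> !torch.vtensor<[],si64>
  return %0 : !torch.vtensor<[],si64>
}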