diff --git a/lib/Conversion/TorchToStablehlo/Reduction.cpp b/lib/Conversion/TorchToStablehlo/Reduction.cpp
index f2e8086de..bca69906d 100644
--- a/lib/Conversion/TorchToStablehlo/Reduction.cpp
+++ b/lib/Conversion/TorchToStablehlo/Reduction.cpp
@@ -110,7 +110,7 @@ static Value createInitialValueForReduceOp(Operation *op, Type elementTy,
     }
   }
 
-  if (isa<AtenAllOp>(op)) {
+  if (isa<AtenAllOp, AtenAllDimOp>(op)) {
     auto constAttr =
         DenseElementsAttr::get(constType, {APInt(/*numBits=*/1, 1)});
     return rewriter.create<stablehlo::ConstantOp>(op->getLoc(), constType,
@@ -166,7 +166,7 @@ static Value createReduceOpWithSingleRegionOp(Operation *op, Value input,
                  AtenLinalgVectorNormOp>(op)) {
     result = rewriter.create<stablehlo::AddOp>(
         op->getLoc(), blockArgumentTy, *firstArgument, *secondArgument);
-  } else if (isa<AtenAllOp>(op)) {
+  } else if (isa<AtenAllOp, AtenAllDimOp>(op)) {
     result = rewriter.create<stablehlo::AndOp>(
         op->getLoc(), blockArgumentTy, *firstArgument, *secondArgument);
   } else if (isa<AtenAnyOp, AtenAnyDimOp>(op)) {
@@ -887,6 +887,7 @@ void mlir::torch::torch_to_stablehlo::populateReductionOpPatternsAndLegality(
   patterns.add<ConvertAtenReduceOneDimOp<AtenOp>>(typeConverter, context,     \
                                                   options)
   INSERT_ATEN_REDUCTION_ONE_DIM_OP_PATTERN(AtenAnyDimOp);
+  INSERT_ATEN_REDUCTION_ONE_DIM_OP_PATTERN(AtenAllDimOp);
 #undef INSERT_ATEN_REDUCTION_ONE_DIM_OP_PATTERN
 
 #define INSERT_ATEN_REDUCTION_DIMS_OP_PATTERN(AtenOp)                         \
diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py
index 0e741d0de..53f1b3647 100644
--- a/projects/pt1/e2e_testing/xfail_sets.py
+++ b/projects/pt1/e2e_testing/xfail_sets.py
@@ -815,10 +815,6 @@ FX_IMPORTER_STABLEHLO_XFAIL_SET = {
     "RandnLikeDtypeModule_basic",
     "RandnLikeModule_basic",
     "RandnModule_basic",
-    "ReduceAllDimBool_basic",
-    "ReduceAllDimEmpty_basic",
-    "ReduceAllDimFloat_basic",
-    "ReduceAllDimInt_basic",
     "ReduceProdDimIntFloatModule_basic",
     "ReflectionPad1dModule2dInput_Right",
     "ReflectionPad1dModule2dInput_basic",
@@ -836,18 +832,7 @@ FX_IMPORTER_STABLEHLO_XFAIL_SET = {
     "ReplicationPad2dModule_top0",
     "RsubInt0d_NumToTensor_Module_basic",
     "ScalarImplicitFloatModule_basic",
-    # need aten.all.dim lowering to stablehlo
-    "SafeSoftmaxModule_basic",
-    "SafeSoftmaxNonNoneDtypeModule_basic",
     # REMOVE WHEN ENABLE_GQA IS ADDED
-    "ScaledDotProductAttentionBoolMaskModule_basic",
-    "ScaledDotProductAttentionDifferentDynamicCausalModule_basic",
-    "ScaledDotProductAttentionDifferentCausalModule_basic",
-    "ScaledDotProductAttentionDifferentModule_basic",
-    "ScaledDotProductAttentionMaskModule_basic",
-    "ScaledDotProductAttentionSameCausalModule_basic",
-    "ScaledDotProductAttentionSameDynamicModule_basic",
-    "ScaledDotProductAttentionSameModule_basic",
     "ScatterReduceFloatMaxModule",
     "ScatterReduceFloatMaxModuleIncludeSelf",
     "ScatterReduceFloatMeanModule",
diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/reduction.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/reduction.py
index e9b84ea06..89774c5d1 100644
--- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/reduction.py
+++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/reduction.py
@@ -170,6 +170,26 @@ def ReduceAllFloatModule_basic(module, tu: TestUtils):
     module.forward(tu.rand(3, 4, 5))
 
 
+class ReduceAllDimFloatModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([-1, -1, -1], torch.float32, True),
+        ]
+    )
+    def forward(self, a):
+        return torch.ops.aten.all(a, dim=0)
+
+
+@register_test_case(module_factory=lambda: ReduceAllDimFloatModule())
+def ReduceAllDimFloatModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(3, 4, 5))
+
+
 # ==============================================================================
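
Illustrative note (not part of the patch): the new ReduceAllDimFloatModule test calls torch.ops.aten.all(a, dim=0) on a float tensor, which reduces dimension 0 with a logical AND over "element is nonzero", consistent with the boolean init value (APInt of 1) and the stablehlo::AndOp reduction region that the C++ change now also applies to AtenAllDimOp. A minimal sketch of the expected semantics in plain PyTorch, with illustrative variable names:

import torch

a = torch.rand(3, 4, 5)
# aten.all.dim on a float input: True where every element along dim 0 is nonzero.
expected = (a != 0).all(dim=0)  # bool tensor of shape (4, 5)
assert torch.equal(torch.ops.aten.all(a, dim=0), expected)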