diff --git a/lib/Conversion/TorchToStablehlo/Basic.cpp b/lib/Conversion/TorchToStablehlo/Basic.cpp index 9bef7de7a..fa1b7cc53 100644 --- a/lib/Conversion/TorchToStablehlo/Basic.cpp +++ b/lib/Conversion/TorchToStablehlo/Basic.cpp @@ -1817,6 +1817,22 @@ LogicalResult ConvertAtenOp::matchAndRewrite( return success(); } +// AtenRemainderTensorOp +template <> +LogicalResult ConvertAtenOp::matchAndRewrite( + AtenRemainderTensorOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + Value lhs = adaptor.getSelf(); + Value rhs = adaptor.getOther(); + + auto resultType = + cast(getTypeConverter()->convertType(op.getType())); + lhs = hlo::promoteType(rewriter, op->getLoc(), lhs, resultType); + rhs = hlo::promoteType(rewriter, op->getLoc(), rhs, resultType); + rewriter.replaceOpWithNewOp(op, lhs, rhs); + return success(); +} + void mlir::torch::torch_to_stablehlo::populateBasicOpPatternsAndLegality( TypeConverter &typeConverter, RewritePatternSet &patterns, ConversionTarget &target, const TorchToStablehloOptions &options) { @@ -1959,6 +1975,7 @@ void mlir::torch::torch_to_stablehlo::populateBasicOpPatternsAndLegality( INSERT_ATENOP_PATTERN(AtenEmptyMemoryFormatOp); INSERT_ATENOP_PATTERN(AtenFillScalarOp); INSERT_ATENOP_PATTERN(AtenFlipOp); + INSERT_ATENOP_PATTERN(AtenRemainderTensorOp); #undef INSERT_ATENOP_PATTERN #define INSERT_BINARY_BROADCAST_PATTERN(AtenOp, StablehloOp) \ diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index 037771b74..ce33b10ca 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -664,9 +664,6 @@ FX_IMPORTER_STABLEHLO_XFAIL_SET = { "ElementwiseQuantizePerTensorModule_basic", "ElementwiseQuantizePerTensorUIntModule_basic", "ElementwiseReciprocalIntModule_basic", - "ElementwiseRemainderTensorModule_Float_basic", - "ElementwiseRemainderTensorModule_Int_Float_basic", - "ElementwiseRemainderTensorModule_Int_basic", "ElementwiseRsqrtIntModule_basic", "ElementwiseSigmoidIntModule_basic", "ElementwiseSinIntModule_basic", @@ -1074,6 +1071,9 @@ STABLEHLO_PASS_SET = { "ElementwisePreluStaticModule_basic", "ElementwiseReciprocalModule_basic", "ElementwiseReluModule_basic", + "ElementwiseRemainderTensorModule_Float_basic", + "ElementwiseRemainderTensorModule_Int_Float_basic", + "ElementwiseRemainderTensorModule_Int_basic", "ElementwiseRsqrtModule_basic", "ElementwiseSigmoidModule_basic", "ElementwiseSinModule_basic",