diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp index a7bdddbc8..4d3711e2e 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp @@ -1153,4 +1153,25 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( binder.op, resultType, tensor, slope); return success(); }); + patterns.onOp("Mod", 13, + [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + Value self, other; + int64_t fmod; + if (binder.tensorOperands(self, other) || + binder.tensorResultType(resultType) || + binder.s64IntegerAttr(fmod, "fmod", 0)) { + return failure(); + } + + if (fmod) { + rewriter.replaceOpWithNewOp( + binder.op, resultType, self, other); + return success(); + } + + rewriter.replaceOpWithNewOp( + binder.op, resultType, self, other); + return success(); + }); } diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index fbd0f5bb6..6c6666a28 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -1611,8 +1611,6 @@ ONNX_XFAIL_SET = { "ElementwiseOrTensorStaticShapeModule_basic", "ElementwiseQuantizePerTensorModule_basic", "ElementwiseRemainderTensorModule_Int_basic", - "ElementwiseFmodTensor_Float_basic", - "ElementwiseFmodTensor_Int_Float_basic", "ElementwiseFmodTensor_Int_basic", "EmptyStridedModule_basic", "EmptyStridedSizeIntStrideModule_basic", @@ -1908,14 +1906,6 @@ ONNX_XFAIL_SET = { "MaxPool2dWithIndicesNonDefaultPaddingModule_basic", "MaxPool2dWithIndicesStaticModule_basic", - # Failure - onnx_lowering: onnx.Mod - "ElementwiseRemainderScalarModule_Bool_basic", - "ElementwiseRemainderScalarModule_Int_basic", - "UnflattenIntNegativeOneDimStaticModule_basic", - "UnflattenIntNegativeOneSizeStaticModule_basic", - "UnflattenIntStaticModule_basic", - "UnflattenStaticModule_basic", - # Failure - onnx_lowering: onnx.OneHot "OneHotModule_basic", diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir index 9dceff316..a348f40e3 100644 --- a/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir +++ b/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir @@ -629,6 +629,24 @@ func.func @test_globalaveragepool_precomputed(%arg0: !torch.vtensor<[1,1,3,3],f3 // ----- +// CHECK-LABEL: func.func @test_mod_int64_fmod +func.func @test_mod_int64_fmod(%arg0: !torch.vtensor<[6],si64>, %arg1: !torch.vtensor<[6],si64>) -> !torch.vtensor<[6],si64> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: torch.aten.fmod.Tensor %arg0, %arg1 : !torch.vtensor<[6],si64>, !torch.vtensor<[6],si64> -> !torch.vtensor<[6],si64> + %0 = torch.operator "onnx.Mod"(%arg0, %arg1) {torch.onnx.fmod = 1 : si64} : (!torch.vtensor<[6],si64>, !torch.vtensor<[6],si64>) -> !torch.vtensor<[6],si64> + return %0 : !torch.vtensor<[6],si64> +} + +// ----- + +// CHECK-LABEL: func.func @test_mod_int64_no_fmod +func.func @test_mod_int64_no_fmod(%arg0: !torch.vtensor<[6],si64>, %arg1: !torch.vtensor<[6],si64>) -> !torch.vtensor<[6],si64> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: torch.aten.remainder.Tensor %arg0, %arg1 : !torch.vtensor<[6],si64>, !torch.vtensor<[6],si64> -> !torch.vtensor<[6],si64> + %0 = torch.operator "onnx.Mod"(%arg0, %arg1) : (!torch.vtensor<[6],si64>, !torch.vtensor<[6],si64>) -> !torch.vtensor<[6],si64> + return %0 : !torch.vtensor<[6],si64> +} + +// ----- + // CHECK-LABEL: func.func @test_log func.func @test_log(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: torch.aten.log %arg0 : !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],f32>