From 16b3bd6e6c8fbf166aad51911ef3fb24e7c96858 Mon Sep 17 00:00:00 2001 From: Vivek Khandelwal Date: Wed, 30 Oct 2024 18:56:01 +0530 Subject: [PATCH] build: manually update PyTorch version and fix CI failure (#3830) This commit sets the PyTorch and TorchVision versions to the nightly release 2024-10-29. This commit also fixes the CI failure that appeared after this commit https://github.com/llvm/torch-mlir/commit/54d9e2401376e7eb2c6c219e3b3555f45f8b2635 got merged. The issue was that the CI checks in the PR were run before the previous roll pytorch update, but the PR was actually merged after the roll pytorch update. Hence, the failure was not caught before merging the PR. While exporting the fx_graph through fx_importer for the `rrelu` and `rrelu_with_noise` ops in train mode, the `aten.rrelu_with_noise` op is decomposed based on the PyTorch decomposition, which is the default behavior. However, the decomposition contains an input mutation, specifically here https://github.com/pytorch/pytorch/blob/9bbe4a67ad137032add6a3b0b74bda66f5ef83d2/torch/_decomp/decompositions.py#L325, resulting in the runtime failure. This issue would probably be fixed by https://github.com/pytorch/pytorch/pull/138503. Until then, the failing tests are added to the xfail set. Also, after the roll pytorch update, the following tests started passing for the fx_importer and fx_importer_stablehlo configs. - "ElementwiseRreluTrainModule_basic" - "ElementwiseRreluTrainStaticModule_basic" - "ElementwiseRreluWithNoiseTrainModule_basic" - "ElementwiseRreluWithNoiseTrainStaticModule_basic" This commit also updates the dtype check for the `aten.linear` op, since the op now expects both input tensors to have the same dtype. 
Signed-Off By: Vivek Khandelwal --- projects/pt1/e2e_testing/xfail_sets.py | 18 ++++++++++-------- .../build_tools/abstract_interp_lib_gen.py | 2 +- pytorch-hash.txt | 2 +- pytorch-requirements.txt | 2 +- torchvision-requirements.txt | 2 +- 5 files changed, 14 insertions(+), 12 deletions(-) diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index 5686664d3..3881aa145 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -420,7 +420,6 @@ FX_IMPORTER_XFAIL_SET = { "DeformConv2D_basic", "DivFloatModule_basic", "DivIntModule_basic", - "ElementwiseAddScalar_NumToTensorFloat_Module_basic", "ElementwiseDequantizePerChannelModule_basic", "ElementwiseDequantizePerTensorModule_basic", "ElementwiseQuantizePerTensorModule_basic", @@ -446,8 +445,6 @@ FX_IMPORTER_XFAIL_SET = { "NllLossModuleBackward1DSum_basic", "NllLossModuleBackward1DWeight_basic", "NllLossModuleBackward1D_basic", - "NumToTensorFloatModule_basic", - "NumToTensorIntModule_basic", "NumelModule_basic", "NumelZeroRankModule_basic", "PowIntFloatModule_basic", @@ -464,7 +461,6 @@ FX_IMPORTER_XFAIL_SET = { "QuantizedSingleLayer_basic", "ReduceMaxAlongDimUnsignedInt_basic", "ReduceMinAlongDimUnsignedInt_basic", - "RsubInt0d_NumToTensor_Module_basic", "ScalarImplicitFloatModule_basic", "SortIntListReverse_basic", "SortIntList_basic", @@ -523,6 +519,11 @@ FX_IMPORTER_XFAIL_SET = { "MeshgridIndexingXY_basic", "Meshgrid_basic", "OneHotModule_basic", + # RuntimeError: cannot mutate tensors with frozen storage + "ElementwiseRreluTrainModule_basic", + "ElementwiseRreluTrainStaticModule_basic", + "ElementwiseRreluWithNoiseTrainModule_basic", + "ElementwiseRreluWithNoiseTrainStaticModule_basic", } FX_IMPORTER_CRASHING_SET = LINALG_CRASHING_SET | { @@ -690,7 +691,6 @@ FX_IMPORTER_STABLEHLO_XFAIL_SET = { "DiagonalModule_with_offset", "DivFloatModule_basic", "DivIntModule_basic", - "ElementwiseAddScalar_NumToTensorFloat_Module_basic", 
"ElementwiseDequantizePerChannelModule_basic", "ElementwiseDequantizePerTensorModule_basic", "ElementwiseErfIntModule_basic", @@ -792,8 +792,6 @@ FX_IMPORTER_STABLEHLO_XFAIL_SET = { "NormScalarComplexModule_basic", "NormScalarModule_basic", "NormalFunctionalModule_basic", - "NumToTensorFloatModule_basic", - "NumToTensorIntModule_basic", "NumelModule_basic", "NumelZeroRankModule_basic", "PowIntFloatModule_basic", @@ -829,7 +827,6 @@ FX_IMPORTER_STABLEHLO_XFAIL_SET = { "ReplicationPad2dModule_left0", "ReplicationPad2dModule_right0", "ReplicationPad2dModule_top0", - "RsubInt0d_NumToTensor_Module_basic", "ScalarImplicitFloatModule_basic", # REMOVE WHEN ENABLE_GQA IS ADDED "ScatterReduceFloatMaxModule", @@ -964,6 +961,11 @@ FX_IMPORTER_STABLEHLO_XFAIL_SET = { "UpSampleNearest2dStaticFactor_basic", "UpSampleNearest2dStaticSize_basic", "UpSampleNearest2d_basic", + # RuntimeError: cannot mutate tensors with frozen storage + "ElementwiseRreluTrainModule_basic", + "ElementwiseRreluTrainStaticModule_basic", + "ElementwiseRreluWithNoiseTrainModule_basic", + "ElementwiseRreluWithNoiseTrainStaticModule_basic", } FX_IMPORTER_STABLEHLO_CRASHING_SET = { diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py index d9e57d674..36ab8fe2c 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py @@ -5371,7 +5371,7 @@ def aten〇atanh〡dtype(self_rank_dtype: Tuple[int, int]) -> int: return torch.float32 return self_dtype -@check_dtype_function(_check_two_tensor_op()) +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=2)) def aten〇linear〡dtype(input_rank_dtype: Tuple[int, int], weight_rank_dtype: Tuple[int, int], bias_rank_dtype: Optional[Tuple[int, int]] = None) -> int: input_rank, input_dtype = 
input_rank_dtype weight_rank, weight_dtype = weight_rank_dtype diff --git a/pytorch-hash.txt b/pytorch-hash.txt index f9e0abfab..dd4f3a19a 100644 --- a/pytorch-hash.txt +++ b/pytorch-hash.txt @@ -1 +1 @@ -160d421a40e934ac8183e47f9cbc8618a4bd97dd +c787213d413e85c66bdad0d8c9cde1c5ced34b1b diff --git a/pytorch-requirements.txt b/pytorch-requirements.txt index ca065711a..960ca904e 100644 --- a/pytorch-requirements.txt +++ b/pytorch-requirements.txt @@ -1,3 +1,3 @@ -f https://download.pytorch.org/whl/nightly/cpu/torch/ --pre -torch==2.6.0.dev20241020 +torch==2.6.0.dev20241029 diff --git a/torchvision-requirements.txt b/torchvision-requirements.txt index 608d687cb..901fbd3d9 100644 --- a/torchvision-requirements.txt +++ b/torchvision-requirements.txt @@ -1,3 +1,3 @@ -f https://download.pytorch.org/whl/nightly/cpu/torchvision/ --pre -torchvision==0.20.0.dev20241020 +torchvision==0.20.0.dev20241029