From 71ac62f3a89f751a2750e922757789ff0cff489e Mon Sep 17 00:00:00 2001
From: Vivek Khandelwal
Date: Fri, 29 Sep 2023 05:09:31 +0000
Subject: [PATCH] build: manually update PyTorch version

Set PyTorch and TorchVision version to nightly release 2023-09-28.

The aten.baddbmm changes are needed because upstream PyTorch has now
added support for fp16 gemm on CPU.
Refer: https://github.com/pytorch/pytorch/commit/9399e0b1ff743d2c968196a3be129307ac360823
---
 .../Transforms/AbstractInterpLibrary.cpp      | 43 ++++++++-----------
 .../build_tools/abstract_interp_lib_gen.py    |  6 +--
 pytorch-hash.txt                              |  2 +-
 pytorch-requirements.txt                      |  2 +-
 torchvision-requirements.txt                  |  2 +-
 5 files changed, 25 insertions(+), 30 deletions(-)

diff --git a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
index de99e529c..2fffbd313 100644
--- a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
+++ b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
@@ -9950,39 +9950,34 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "  func.func @\"__torch_mlir_dtype_fn.aten.baddbmm\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.tuple<int, int>, %arg3: !torch.number, %arg4: !torch.number) -> !torch.int {\n"
 "    %none = torch.constant.none\n"
 "    %str = torch.constant.str \"AssertionError: \"\n"
-"    %int5 = torch.constant.int 5\n"
 "    %int11 = torch.constant.int 11\n"
 "    %0:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
 "    %1:2 = torch.prim.TupleUnpack %arg2 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
-"    %2 = torch.prim.ListConstruct %int11, %int5 : (!torch.int, !torch.int) -> !torch.list<int>\n"
-"    %3 = torch.aten.__contains__.int_list %2, %0#1 : !torch.list<int>, !torch.int -> !torch.bool\n"
-"    %4 = torch.aten.__not__ %3 : !torch.bool -> !torch.bool\n"
+"    %2 = torch.aten.__isnot__ %0#1, %int11 : !torch.int, !torch.int -> !torch.bool\n"
+"    torch.prim.If %2 -> () {\n"
+"      torch.prim.If.yield\n"
+"    } else {\n"
+"      torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
+"      torch.prim.If.yield\n"
+"    }\n"
+"    %3 = torch.aten.__isnot__ %1#1, %int11 : !torch.int, !torch.int -> !torch.bool\n"
+"    torch.prim.If %3 -> () {\n"
+"      torch.prim.If.yield\n"
+"    } else {\n"
+"      torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
+"      torch.prim.If.yield\n"
+"    }\n"
+"    %4 = torch.aten.eq.int %0#1, %1#1 : !torch.int, !torch.int -> !torch.bool\n"
 "    torch.prim.If %4 -> () {\n"
 "      torch.prim.If.yield\n"
 "    } else {\n"
 "      torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
 "      torch.prim.If.yield\n"
 "    }\n"
-"    %5 = torch.prim.ListConstruct %int11, %int5 : (!torch.int, !torch.int) -> !torch.list<int>\n"
-"    %6 = torch.aten.__contains__.int_list %5, %1#1 : !torch.list<int>, !torch.int -> !torch.bool\n"
-"    %7 = torch.aten.__not__ %6 : !torch.bool -> !torch.bool\n"
-"    torch.prim.If %7 -> () {\n"
-"      torch.prim.If.yield\n"
-"    } else {\n"
-"      torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
-"      torch.prim.If.yield\n"
-"    }\n"
-"    %8 = torch.aten.eq.int %0#1, %1#1 : !torch.int, !torch.int -> !torch.bool\n"
-"    torch.prim.If %8 -> () {\n"
-"      torch.prim.If.yield\n"
-"    } else {\n"
-"      torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
-"      torch.prim.If.yield\n"
-"    }\n"
-"    %9 = torch.prim.ListConstruct %0#0, %1#0 : (!torch.int, !torch.int) -> !torch.list<optional<int>>\n"
-"    %10 = torch.prim.ListConstruct %0#1, %1#1 : (!torch.int, !torch.int) -> !torch.list<int>\n"
-"    %11 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%9, %10) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n"
-"    return %11 : !torch.int\n"
+"    %5 = torch.prim.ListConstruct %0#0, %1#0 : (!torch.int, !torch.int) -> !torch.list<optional<int>>\n"
+"    %6 = torch.prim.ListConstruct %0#1, %1#1 : (!torch.int, !torch.int) -> !torch.list<int>\n"
+"    %7 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%5, %6) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n"
+"    return %7 : !torch.int\n"
 "  }\n"
 "  func.func @\"__torch_mlir_dtype_fn.aten.where.self\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.tuple<int, int>) -> !torch.int {\n"
 "    %0:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/abstract_interp_lib_gen.py b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/abstract_interp_lib_gen.py
index d142770b1..7c507f53b 100644
--- a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/abstract_interp_lib_gen.py
+++ b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/abstract_interp_lib_gen.py
@@ -2822,7 +2822,7 @@ def aten〇remainder〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], other: U
 
 # TODO: This should be fixed by switching to FakeTensor instead of Meta tensor
 @check_dtype_function(
-    _check_tensors_with_the_same_dtype(tensor_shapes=[(1, 1, 1), (1, 1, 1), (1, 1, 1)], tensor_device="cpu", error_types={torch.bool, torch.float16}) +
+    _check_tensors_with_the_same_dtype(tensor_shapes=[(1, 1, 1), (1, 1, 1), (1, 1, 1)], tensor_device="cpu", error_types={torch.bool}) +
     [ErrorInvocation(TensorOfShape(
         1, 1, 1, dtype=torch.float64, device="cpu"), TensorOfShape(1, 1, 1, dtype=torch.int16, device="cpu"), TensorOfShape(1, 1, 1, dtype=torch.int32, device="cpu")),
     ErrorInvocation(
@@ -2834,8 +2834,8 @@ def aten〇remainder〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], other: U
 def aten〇baddbmm〡dtype(self_rank_dtype: Tuple[int, int], batch1_rank_dtype: Tuple[int, int], batch2_rank_dtype: Tuple[int, int], beta: Union[int, float, complex] = 1, alpha: Union[int, float, complex] = 1) -> int:
     batch1_rank, batch1_dtype = batch1_rank_dtype
     batch2_rank, batch2_dtype = batch2_rank_dtype
-    assert batch1_dtype not in [torch.bool, torch.float16]
-    assert batch2_dtype not in [torch.bool, torch.float16]
+    assert batch1_dtype is not torch.bool
+    assert batch2_dtype is not torch.bool
     assert batch1_dtype == batch2_dtype
     ranks: List[Optional[int]] = [batch1_rank, batch2_rank]
     dtypes = [batch1_dtype, batch2_dtype]
diff --git a/pytorch-hash.txt b/pytorch-hash.txt
index fb206a10a..c0a7e8b3e 100644
--- a/pytorch-hash.txt
+++ b/pytorch-hash.txt
@@ -1 +1 @@
-d7520d8668dc08f7bed27a64f006c909006e653a
+fecde478ac83edf78e7d0e9d11ab73cb1580f6cf
diff --git a/pytorch-requirements.txt b/pytorch-requirements.txt
index 3dfe294db..4e3a9152f 100644
--- a/pytorch-requirements.txt
+++ b/pytorch-requirements.txt
@@ -1,3 +1,3 @@
 -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html
 --pre
-torch==2.2.0.dev20230927
+torch==2.2.0.dev20230928
diff --git a/torchvision-requirements.txt b/torchvision-requirements.txt
index ae783e410..9c653bbcd 100644
--- a/torchvision-requirements.txt
+++ b/torchvision-requirements.txt
@@ -1,3 +1,3 @@
 -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html
 --pre
-torchvision==0.17.0.dev20230927
+torchvision==0.17.0.dev20230928
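
Note: a quick sanity check of the upstream behavior this bump relies on
(a hypothetical snippet, not part of the patch; assumes the
torch==2.2.0.dev20230928 nightly pinned above):

    import torch

    # Upstream now supports fp16 gemm on CPU (pytorch/pytorch commit
    # 9399e0b1), so aten.baddbmm no longer has to reject torch.float16
    # inputs on CPU; only torch.bool remains an error dtype.
    inp = torch.zeros(1, 1, 1, dtype=torch.float16)
    batch1 = torch.ones(1, 1, 1, dtype=torch.float16)
    batch2 = torch.ones(1, 1, 1, dtype=torch.float16)
    out = torch.baddbmm(inp, batch1, batch2)  # raised on older releases
    print(out.dtype)  # torch.float16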