build: manually update PyTorch version

Set PyTorch and TorchVision version to nightly release 2023-09-28.

aten.baddbmm changes done because upstream PyTorch has now added
support for fp16 gemm on CPU.
Refer: 9399e0b1ff
pull/2491/head
Vivek Khandelwal 2023-09-29 05:09:31 +00:00
parent 860be09a39
commit 71ac62f3a8
5 changed files with 25 additions and 30 deletions

View File

@ -9950,39 +9950,34 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" func.func @\"__torch_mlir_dtype_fn.aten.baddbmm\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.tuple<int, int>, %arg3: !torch.number, %arg4: !torch.number) -> !torch.int {\n"
" %none = torch.constant.none\n"
" %str = torch.constant.str \"AssertionError: \"\n"
" %int5 = torch.constant.int 5\n"
" %int11 = torch.constant.int 11\n"
" %0:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" %1:2 = torch.prim.TupleUnpack %arg2 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" %2 = torch.prim.ListConstruct %int11, %int5 : (!torch.int, !torch.int) -> !torch.list<int>\n"
" %3 = torch.aten.__contains__.int_list %2, %0#1 : !torch.list<int>, !torch.int -> !torch.bool\n"
" %4 = torch.aten.__not__ %3 : !torch.bool -> !torch.bool\n"
" %2 = torch.aten.__isnot__ %0#1, %int11 : !torch.int, !torch.int -> !torch.bool\n"
" torch.prim.If %2 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" %3 = torch.aten.__isnot__ %1#1, %int11 : !torch.int, !torch.int -> !torch.bool\n"
" torch.prim.If %3 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" %4 = torch.aten.eq.int %0#1, %1#1 : !torch.int, !torch.int -> !torch.bool\n"
" torch.prim.If %4 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" %5 = torch.prim.ListConstruct %int11, %int5 : (!torch.int, !torch.int) -> !torch.list<int>\n"
" %6 = torch.aten.__contains__.int_list %5, %1#1 : !torch.list<int>, !torch.int -> !torch.bool\n"
" %7 = torch.aten.__not__ %6 : !torch.bool -> !torch.bool\n"
" torch.prim.If %7 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" %8 = torch.aten.eq.int %0#1, %1#1 : !torch.int, !torch.int -> !torch.bool\n"
" torch.prim.If %8 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" %9 = torch.prim.ListConstruct %0#0, %1#0 : (!torch.int, !torch.int) -> !torch.list<optional<int>>\n"
" %10 = torch.prim.ListConstruct %0#1, %1#1 : (!torch.int, !torch.int) -> !torch.list<int>\n"
" %11 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%9, %10) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n"
" return %11 : !torch.int\n"
" %5 = torch.prim.ListConstruct %0#0, %1#0 : (!torch.int, !torch.int) -> !torch.list<optional<int>>\n"
" %6 = torch.prim.ListConstruct %0#1, %1#1 : (!torch.int, !torch.int) -> !torch.list<int>\n"
" %7 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%5, %6) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n"
" return %7 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.where.self\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.tuple<int, int>) -> !torch.int {\n"
" %0:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"

View File

@ -2822,7 +2822,7 @@ def atenremainderScalar〡dtype(self_rank_dtype: Tuple[int, int], other: U
# TODO: This should be fixed by switching to FakeTensor instead of Meta tensor
@check_dtype_function(
_check_tensors_with_the_same_dtype(tensor_shapes=[(1, 1, 1), (1, 1, 1), (1, 1, 1)], tensor_device="cpu", error_types={torch.bool, torch.float16}) +
_check_tensors_with_the_same_dtype(tensor_shapes=[(1, 1, 1), (1, 1, 1), (1, 1, 1)], tensor_device="cpu", error_types={torch.bool}) +
[ErrorInvocation(TensorOfShape(
1, 1, 1, dtype=torch.float64, device="cpu"), TensorOfShape(1, 1, 1, dtype=torch.int16, device="cpu"), TensorOfShape(1, 1, 1, dtype=torch.int32, device="cpu")),
ErrorInvocation(
@ -2834,8 +2834,8 @@ def atenremainderScalar〡dtype(self_rank_dtype: Tuple[int, int], other: U
def atenbaddbmm〡dtype(self_rank_dtype: Tuple[int, int], batch1_rank_dtype: Tuple[int, int], batch2_rank_dtype: Tuple[int, int], beta: Union[int, float, complex] = 1, alpha: Union[int, float, complex] = 1) -> int:
batch1_rank, batch1_dtype = batch1_rank_dtype
batch2_rank, batch2_dtype = batch2_rank_dtype
assert batch1_dtype not in [torch.bool, torch.float16]
assert batch2_dtype not in [torch.bool, torch.float16]
assert batch1_dtype is not torch.bool
assert batch2_dtype is not torch.bool
assert batch1_dtype == batch2_dtype
ranks: List[Optional[int]] = [batch1_rank, batch2_rank]
dtypes = [batch1_dtype, batch2_dtype]

View File

@ -1 +1 @@
d7520d8668dc08f7bed27a64f006c909006e653a
fecde478ac83edf78e7d0e9d11ab73cb1580f6cf

View File

@ -1,3 +1,3 @@
-f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html
--pre
torch==2.2.0.dev20230927
torch==2.2.0.dev20230928

View File

@ -1,3 +1,3 @@
-f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html
--pre
torchvision==0.17.0.dev20230927
torchvision==0.17.0.dev20230928