mirror of https://github.com/llvm/torch-mlir
build: manually update PyTorch version
Set PyTorch and TorchVision version to nightly release 2023-09-28.
aten.baddbmm changes done because upstream PyTorch has now added
support for fp16 gemm on CPU.
Refer: 9399e0b1ff
pull/2491/head
parent
860be09a39
commit
71ac62f3a8
|
@ -9950,39 +9950,34 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
|
|||
" func.func @\"__torch_mlir_dtype_fn.aten.baddbmm\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.tuple<int, int>, %arg3: !torch.number, %arg4: !torch.number) -> !torch.int {\n"
|
||||
" %none = torch.constant.none\n"
|
||||
" %str = torch.constant.str \"AssertionError: \"\n"
|
||||
" %int5 = torch.constant.int 5\n"
|
||||
" %int11 = torch.constant.int 11\n"
|
||||
" %0:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
|
||||
" %1:2 = torch.prim.TupleUnpack %arg2 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
|
||||
" %2 = torch.prim.ListConstruct %int11, %int5 : (!torch.int, !torch.int) -> !torch.list<int>\n"
|
||||
" %3 = torch.aten.__contains__.int_list %2, %0#1 : !torch.list<int>, !torch.int -> !torch.bool\n"
|
||||
" %4 = torch.aten.__not__ %3 : !torch.bool -> !torch.bool\n"
|
||||
" %2 = torch.aten.__isnot__ %0#1, %int11 : !torch.int, !torch.int -> !torch.bool\n"
|
||||
" torch.prim.If %2 -> () {\n"
|
||||
" torch.prim.If.yield\n"
|
||||
" } else {\n"
|
||||
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
|
||||
" torch.prim.If.yield\n"
|
||||
" }\n"
|
||||
" %3 = torch.aten.__isnot__ %1#1, %int11 : !torch.int, !torch.int -> !torch.bool\n"
|
||||
" torch.prim.If %3 -> () {\n"
|
||||
" torch.prim.If.yield\n"
|
||||
" } else {\n"
|
||||
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
|
||||
" torch.prim.If.yield\n"
|
||||
" }\n"
|
||||
" %4 = torch.aten.eq.int %0#1, %1#1 : !torch.int, !torch.int -> !torch.bool\n"
|
||||
" torch.prim.If %4 -> () {\n"
|
||||
" torch.prim.If.yield\n"
|
||||
" } else {\n"
|
||||
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
|
||||
" torch.prim.If.yield\n"
|
||||
" }\n"
|
||||
" %5 = torch.prim.ListConstruct %int11, %int5 : (!torch.int, !torch.int) -> !torch.list<int>\n"
|
||||
" %6 = torch.aten.__contains__.int_list %5, %1#1 : !torch.list<int>, !torch.int -> !torch.bool\n"
|
||||
" %7 = torch.aten.__not__ %6 : !torch.bool -> !torch.bool\n"
|
||||
" torch.prim.If %7 -> () {\n"
|
||||
" torch.prim.If.yield\n"
|
||||
" } else {\n"
|
||||
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
|
||||
" torch.prim.If.yield\n"
|
||||
" }\n"
|
||||
" %8 = torch.aten.eq.int %0#1, %1#1 : !torch.int, !torch.int -> !torch.bool\n"
|
||||
" torch.prim.If %8 -> () {\n"
|
||||
" torch.prim.If.yield\n"
|
||||
" } else {\n"
|
||||
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
|
||||
" torch.prim.If.yield\n"
|
||||
" }\n"
|
||||
" %9 = torch.prim.ListConstruct %0#0, %1#0 : (!torch.int, !torch.int) -> !torch.list<optional<int>>\n"
|
||||
" %10 = torch.prim.ListConstruct %0#1, %1#1 : (!torch.int, !torch.int) -> !torch.list<int>\n"
|
||||
" %11 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%9, %10) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n"
|
||||
" return %11 : !torch.int\n"
|
||||
" %5 = torch.prim.ListConstruct %0#0, %1#0 : (!torch.int, !torch.int) -> !torch.list<optional<int>>\n"
|
||||
" %6 = torch.prim.ListConstruct %0#1, %1#1 : (!torch.int, !torch.int) -> !torch.list<int>\n"
|
||||
" %7 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%5, %6) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n"
|
||||
" return %7 : !torch.int\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_dtype_fn.aten.where.self\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.tuple<int, int>) -> !torch.int {\n"
|
||||
" %0:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
|
||||
|
|
|
@ -2822,7 +2822,7 @@ def aten〇remainder〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], other: U
|
|||
|
||||
# TODO: This should be fixed by switching to FakeTensor instead of Meta tensor
|
||||
@check_dtype_function(
|
||||
_check_tensors_with_the_same_dtype(tensor_shapes=[(1, 1, 1), (1, 1, 1), (1, 1, 1)], tensor_device="cpu", error_types={torch.bool, torch.float16}) +
|
||||
_check_tensors_with_the_same_dtype(tensor_shapes=[(1, 1, 1), (1, 1, 1), (1, 1, 1)], tensor_device="cpu", error_types={torch.bool}) +
|
||||
[ErrorInvocation(TensorOfShape(
|
||||
1, 1, 1, dtype=torch.float64, device="cpu"), TensorOfShape(1, 1, 1, dtype=torch.int16, device="cpu"), TensorOfShape(1, 1, 1, dtype=torch.int32, device="cpu")),
|
||||
ErrorInvocation(
|
||||
|
@ -2834,8 +2834,8 @@ def aten〇remainder〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], other: U
|
|||
def aten〇baddbmm〡dtype(self_rank_dtype: Tuple[int, int], batch1_rank_dtype: Tuple[int, int], batch2_rank_dtype: Tuple[int, int], beta: Union[int, float, complex] = 1, alpha: Union[int, float, complex] = 1) -> int:
|
||||
batch1_rank, batch1_dtype = batch1_rank_dtype
|
||||
batch2_rank, batch2_dtype = batch2_rank_dtype
|
||||
assert batch1_dtype not in [torch.bool, torch.float16]
|
||||
assert batch2_dtype not in [torch.bool, torch.float16]
|
||||
assert batch1_dtype is not torch.bool
|
||||
assert batch2_dtype is not torch.bool
|
||||
assert batch1_dtype == batch2_dtype
|
||||
ranks: List[Optional[int]] = [batch1_rank, batch2_rank]
|
||||
dtypes = [batch1_dtype, batch2_dtype]
|
||||
|
|
|
@ -1 +1 @@
|
|||
d7520d8668dc08f7bed27a64f006c909006e653a
|
||||
fecde478ac83edf78e7d0e9d11ab73cb1580f6cf
|
||||
|
|
|
@ -1,3 +1,3 @@
|
|||
-f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html
|
||||
--pre
|
||||
torch==2.2.0.dev20230927
|
||||
torch==2.2.0.dev20230928
|
||||
|
|
|
@ -1,3 +1,3 @@
|
|||
-f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html
|
||||
--pre
|
||||
torchvision==0.17.0.dev20230927
|
||||
torchvision==0.17.0.dev20230928
|
||||
|
|
Loading…
Reference in New Issue