Fix handling of `!torch.number` in abstract interpretation library (#2309)

In PyTorch, `NumberType` is equal to `Union[int, float, complex]`.
However, the abstract interpretation library was treating `NumberType`
as `Union[int, float]`, resulting in type mismatches when reifying
certain dtype functions. This commit fixes the inconsistency by having
the abstract interpretation functions take `Union[int, float, complex]`
inputs for the ops that take `!torch.number` inputs.
Ramiro Leal-Cavazos 2023-07-17 09:52:04 -07:00 committed by GitHub
parent 5706697e0b
commit 718f53ff8a
6 changed files with 159 additions and 137 deletions
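
For context, the generated `get_dtype_of_scalar` in the library below recovers a scalar's dtype by boxing it into a tensor (`prim.NumToTensor.Scalar` followed by `prim.dtype`). A minimal standalone Python sketch of that mapping, now covering the `complex` case as well (illustrative only, not the library's actual helper):

import torch
from typing import Union

def dtype_of_scalar(x: Union[int, float, complex]) -> torch.dtype:
    # Box the scalar into a 0-d tensor and read back the dtype PyTorch
    # assigns to it; this mirrors what the library does via
    # prim.NumToTensor.Scalar followed by prim.dtype.
    return torch.tensor(x).dtype

assert dtype_of_scalar(1) == torch.int64
assert dtype_of_scalar(1.0) == torch.float32   # with the default float dtype
assert dtype_of_scalar(1j) == torch.complex64  # the case NumberType now accounts for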


@ -7287,7 +7287,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %0 = torch.prim.ListConstruct : () -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.scalar_tensor\"(%arg0: !torch.union<float, int>, %arg1: !torch.optional<int>, %arg2: !torch.optional<int>, %arg3: !torch.optional<Device>, %arg4: !torch.optional<bool>) -> !torch.int {\n"
" func.func @\"__torch_mlir_dtype_fn.aten.scalar_tensor\"(%arg0: !torch.number, %arg1: !torch.optional<int>, %arg2: !torch.optional<int>, %arg3: !torch.optional<Device>, %arg4: !torch.optional<bool>) -> !torch.int {\n"
" %int6 = torch.constant.int 6\n"
" %none = torch.constant.none\n"
" %0 = torch.aten.__isnot__ %arg1, %none : !torch.optional<int>, !torch.none -> !torch.bool\n"
@ -8026,7 +8026,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %1 = call @__torch__._get_dtype_of_floating_point_op(%0#1) : (!torch.int) -> !torch.int\n"
" return %1 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.softplus\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.union<float, int>, %arg2: !torch.union<float, int>) -> !torch.int {\n"
" func.func @\"__torch_mlir_dtype_fn.aten.softplus\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.number, %arg2: !torch.number) -> !torch.int {\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" %1 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_integer_dtype(%0#1) : (!torch.int) -> !torch.bool\n"
" %2 = torch.prim.If %1 -> (!torch.int) {\n"
@ -8166,7 +8166,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" return %0#1 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.clamp_max\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.union<float, int>) -> !torch.int {\n"
" func.func @\"__torch_mlir_dtype_fn.aten.clamp_max\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.number) -> !torch.int {\n"
" %int4 = torch.constant.int 4\n"
" %int11 = torch.constant.int 11\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
@ -8178,7 +8178,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" }\n"
" return %2 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.clamp_min\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.union<float, int>) -> !torch.int {\n"
" func.func @\"__torch_mlir_dtype_fn.aten.clamp_min\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.number) -> !torch.int {\n"
" %int4 = torch.constant.int 4\n"
" %int11 = torch.constant.int 11\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
@ -8190,7 +8190,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" }\n"
" return %2 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.clamp\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.union<float, int, none>, %arg2: !torch.union<float, int, none>) -> !torch.int {\n"
" func.func @\"__torch_mlir_dtype_fn.aten.clamp\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.optional<number>, %arg2: !torch.optional<number>) -> !torch.int {\n"
" %int4 = torch.constant.int 4\n"
" %int11 = torch.constant.int 11\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
@ -8206,7 +8206,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" return %0#1 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.constant_pad_nd\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.list<int>, %arg2: !torch.union<float, int>) -> !torch.int {\n"
" func.func @\"__torch_mlir_dtype_fn.aten.constant_pad_nd\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.list<int>, %arg2: !torch.number) -> !torch.int {\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" return %0#1 : !torch.int\n"
" }\n"
@ -8263,7 +8263,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" return %0#1 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.fill.Scalar\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.union<float, int>) -> !torch.int {\n"
" func.func @\"__torch_mlir_dtype_fn.aten.fill.Scalar\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.number) -> !torch.int {\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" return %0#1 : !torch.int\n"
" }\n"
@ -8315,7 +8315,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" return %0#1 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.hardtanh_backward\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.union<float, int>, %arg3: !torch.union<float, int>) -> !torch.int {\n"
" func.func @\"__torch_mlir_dtype_fn.aten.hardtanh_backward\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.number, %arg3: !torch.number) -> !torch.int {\n"
" %int6 = torch.constant.int 6\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" %1 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_integer_dtype(%0#1) : (!torch.int) -> !torch.bool\n"
@ -8326,7 +8326,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" }\n"
" return %2 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.hardtanh\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.union<float, int>, %arg2: !torch.union<float, int>) -> !torch.int {\n"
" func.func @\"__torch_mlir_dtype_fn.aten.hardtanh\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.number, %arg2: !torch.number) -> !torch.int {\n"
" %none = torch.constant.none\n"
" %str = torch.constant.str \"AssertionError: \"\n"
" %int11 = torch.constant.int 11\n"
@ -8381,7 +8381,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" }\n"
" return %0#1 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.leaky_relu_backward\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.union<float, int>, %arg3: !torch.bool) -> !torch.int {\n"
" func.func @\"__torch_mlir_dtype_fn.aten.leaky_relu_backward\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.number, %arg3: !torch.bool) -> !torch.int {\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" %1:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" %2 = torch.prim.ListConstruct %0#0, %1#0 : (!torch.int, !torch.int) -> !torch.list<optional<int>>\n"
@ -8396,11 +8396,11 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" func.func @\"__torch_mlir_dtype_fn.aten._log_softmax_backward_data\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.int, %arg3: !torch.int) -> !torch.int {\n"
" return %arg3 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.masked_fill.Scalar\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.union<float, int>) -> !torch.int {\n"
" func.func @\"__torch_mlir_dtype_fn.aten.masked_fill.Scalar\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.number) -> !torch.int {\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" return %0#1 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.masked_fill_.Scalar\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.union<float, int>) -> !torch.int {\n"
" func.func @\"__torch_mlir_dtype_fn.aten.masked_fill_.Scalar\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.number) -> !torch.int {\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" return %0#1 : !torch.int\n"
" }\n"
@ -8544,7 +8544,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" return %0#1 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.scatter.value\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.int, %arg2: !torch.tuple<int, int>, %arg3: !torch.union<float, int>) -> !torch.int {\n"
" func.func @\"__torch_mlir_dtype_fn.aten.scatter.value\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.int, %arg2: !torch.tuple<int, int>, %arg3: !torch.number) -> !torch.int {\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" return %0#1 : !torch.int\n"
" }\n"
@ -8591,7 +8591,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n"
" return %4 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.threshold\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.union<float, int>, %arg2: !torch.union<float, int>) -> !torch.int {\n"
" func.func @\"__torch_mlir_dtype_fn.aten.threshold\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.number, %arg2: !torch.number) -> !torch.int {\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" return %0#1 : !torch.int\n"
" }\n"
@ -8647,12 +8647,12 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" return %0#1 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.prim.abs.Scalar\"(%arg0: !torch.union<float, int>) -> !torch.int {\n"
" %0 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg0) : (!torch.union<float, int>) -> !torch.int\n"
" func.func @\"__torch_mlir_dtype_fn.prim.abs.Scalar\"(%arg0: !torch.number) -> !torch.int {\n"
" %0 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg0) : (!torch.number) -> !torch.int\n"
" return %0 : !torch.int\n"
" }\n"
" func.func @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg0: !torch.union<float, int>) -> !torch.int {\n"
" %0 = torch.prim.NumToTensor.Scalar %arg0 : !torch.union<float, int> -> !torch.tensor\n"
" func.func @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg0: !torch.number) -> !torch.int {\n"
" %0 = torch.prim.NumToTensor.Scalar %arg0 : !torch.number -> !torch.tensor\n"
" %1 = torch.prim.dtype %0 : !torch.tensor -> !torch.int\n"
" return %1 : !torch.int\n"
" }\n"
@ -8710,7 +8710,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" }\n"
" return %2 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.eq.Scalar\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.union<float, int>) -> !torch.int {\n"
" func.func @\"__torch_mlir_dtype_fn.aten.eq.Scalar\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.number) -> !torch.int {\n"
" %int11 = torch.constant.int 11\n"
" return %int11 : !torch.int\n"
" }\n"
@ -8718,11 +8718,11 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %int11 = torch.constant.int 11\n"
" return %int11 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.ge.Scalar\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.union<float, int>) -> !torch.int {\n"
" func.func @\"__torch_mlir_dtype_fn.aten.ge.Scalar\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.number) -> !torch.int {\n"
" %int11 = torch.constant.int 11\n"
" return %int11 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.gt.Scalar\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.union<float, int>) -> !torch.int {\n"
" func.func @\"__torch_mlir_dtype_fn.aten.gt.Scalar\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.number) -> !torch.int {\n"
" %int11 = torch.constant.int 11\n"
" return %int11 : !torch.int\n"
" }\n"
@ -8734,7 +8734,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %int11 = torch.constant.int 11\n"
" return %int11 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.le.Scalar\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.union<float, int>) -> !torch.int {\n"
" func.func @\"__torch_mlir_dtype_fn.aten.le.Scalar\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.number) -> !torch.int {\n"
" %int11 = torch.constant.int 11\n"
" return %int11 : !torch.int\n"
" }\n"
@ -8758,7 +8758,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %int11 = torch.constant.int 11\n"
" return %int11 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.lt.Scalar\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.union<float, int>) -> !torch.int {\n"
" func.func @\"__torch_mlir_dtype_fn.aten.lt.Scalar\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.number) -> !torch.int {\n"
" %int11 = torch.constant.int 11\n"
" return %int11 : !torch.int\n"
" }\n"
@ -8778,15 +8778,15 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %int11 = torch.constant.int 11\n"
" return %int11 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.ne.Scalar\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.union<float, int>) -> !torch.int {\n"
" func.func @\"__torch_mlir_dtype_fn.aten.ne.Scalar\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.number) -> !torch.int {\n"
" %int11 = torch.constant.int 11\n"
" return %int11 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.add\"(%arg0: !torch.union<float, int>, %arg1: !torch.union<float, int>) -> !torch.int {\n"
" func.func @\"__torch_mlir_dtype_fn.aten.add\"(%arg0: !torch.number, %arg1: !torch.number) -> !torch.int {\n"
" %none = torch.constant.none\n"
" %0 = torch.prim.ListConstruct %none, %none : (!torch.none, !torch.none) -> !torch.list<optional<int>>\n"
" %1 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg0) : (!torch.union<float, int>) -> !torch.int\n"
" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg1) : (!torch.union<float, int>) -> !torch.int\n"
" %1 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg0) : (!torch.number) -> !torch.int\n"
" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg1) : (!torch.number) -> !torch.int\n"
" %3 = torch.prim.ListConstruct %1, %2 : (!torch.int, !torch.int) -> !torch.list<int>\n"
" %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%0, %3) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n"
" return %4 : !torch.int\n"
@ -8835,11 +8835,11 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" }\n"
" return %3 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.rsub.Scalar\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.union<float, int>, %arg2: !torch.union<float, int>) -> !torch.int {\n"
" func.func @\"__torch_mlir_dtype_fn.aten.rsub.Scalar\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.number, %arg2: !torch.number) -> !torch.int {\n"
" %none = torch.constant.none\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" %1 = torch.prim.ListConstruct %0#0, %none : (!torch.int, !torch.none) -> !torch.list<optional<int>>\n"
" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg1) : (!torch.union<float, int>) -> !torch.int\n"
" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg1) : (!torch.number) -> !torch.int\n"
" %3 = torch.prim.ListConstruct %0#1, %2 : (!torch.int, !torch.int) -> !torch.list<int>\n"
" %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%1, %3) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n"
" return %4 : !torch.int\n"
@ -8852,7 +8852,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n"
" return %4 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.add.Tensor\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.union<float, int>) -> !torch.int {\n"
" func.func @\"__torch_mlir_dtype_fn.aten.add.Tensor\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.number) -> !torch.int {\n"
" %0:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" %1:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" %2 = torch.prim.ListConstruct %1#0, %0#0 : (!torch.int, !torch.int) -> !torch.list<optional<int>>\n"
@ -9176,7 +9176,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n"
" return %4 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.sub.Tensor\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.union<float, int>) -> !torch.int {\n"
" func.func @\"__torch_mlir_dtype_fn.aten.sub.Tensor\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.number) -> !torch.int {\n"
" %0:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" %1:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" %2 = torch.prim.ListConstruct %1#0, %0#0 : (!torch.int, !torch.int) -> !torch.list<optional<int>>\n"
@ -9184,7 +9184,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n"
" return %4 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.threshold_backward\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.union<float, int>) -> !torch.int {\n"
" func.func @\"__torch_mlir_dtype_fn.aten.threshold_backward\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.number) -> !torch.int {\n"
" %str = torch.constant.str \"AssertionError: Result dtype for aten.threshold_backward cannot be bool or float16\"\n"
" %int11 = torch.constant.int 11\n"
" %str_0 = torch.constant.str \"AssertionError: `self` cannot be complex\"\n"
@ -9358,7 +9358,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" }\n"
" return %4 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.addmm\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.tuple<int, int>, %arg3: !torch.union<float, int>, %arg4: !torch.union<float, int>) -> !torch.int {\n"
" func.func @\"__torch_mlir_dtype_fn.aten.addmm\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.tuple<int, int>, %arg3: !torch.number, %arg4: !torch.number) -> !torch.int {\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" %1:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" %2:2 = torch.prim.TupleUnpack %arg2 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
@ -9376,7 +9376,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %5 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%3, %4) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n"
" return %5 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.addcmul\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.tuple<int, int>, %arg3: !torch.union<float, int>) -> !torch.int {\n"
" func.func @\"__torch_mlir_dtype_fn.aten.addcmul\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.tuple<int, int>, %arg3: !torch.number) -> !torch.int {\n"
" %none = torch.constant.none\n"
" %str = torch.constant.str \"AssertionError: \"\n"
" %int11 = torch.constant.int 11\n"
@ -9409,7 +9409,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %8 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%6, %7) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n"
" return %8 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.addcdiv\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.tuple<int, int>, %arg3: !torch.union<float, int>) -> !torch.int {\n"
" func.func @\"__torch_mlir_dtype_fn.aten.addcdiv\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.tuple<int, int>, %arg3: !torch.number) -> !torch.int {\n"
" %int6 = torch.constant.int 6\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" %1:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
@ -9425,39 +9425,39 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" }\n"
" return %7 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.add.Scalar\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.union<float, int>, %arg2: !torch.union<float, int>) -> !torch.int {\n"
" func.func @\"__torch_mlir_dtype_fn.aten.add.Scalar\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.number, %arg2: !torch.number) -> !torch.int {\n"
" %none = torch.constant.none\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" %1 = torch.prim.ListConstruct %0#0, %none : (!torch.int, !torch.none) -> !torch.list<optional<int>>\n"
" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg1) : (!torch.union<float, int>) -> !torch.int\n"
" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg1) : (!torch.number) -> !torch.int\n"
" %3 = torch.prim.ListConstruct %0#1, %2 : (!torch.int, !torch.int) -> !torch.list<int>\n"
" %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%1, %3) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n"
" return %4 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.sub.Scalar\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.union<float, int>, %arg2: !torch.union<float, int>) -> !torch.int {\n"
" func.func @\"__torch_mlir_dtype_fn.aten.sub.Scalar\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.number, %arg2: !torch.number) -> !torch.int {\n"
" %none = torch.constant.none\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" %1 = torch.prim.ListConstruct %0#0, %none : (!torch.int, !torch.none) -> !torch.list<optional<int>>\n"
" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg1) : (!torch.union<float, int>) -> !torch.int\n"
" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg1) : (!torch.number) -> !torch.int\n"
" %3 = torch.prim.ListConstruct %0#1, %2 : (!torch.int, !torch.int) -> !torch.list<int>\n"
" %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%1, %3) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n"
" return %4 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.mul.Scalar\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.union<float, int>) -> !torch.int {\n"
" func.func @\"__torch_mlir_dtype_fn.aten.mul.Scalar\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.number) -> !torch.int {\n"
" %none = torch.constant.none\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" %1 = torch.prim.ListConstruct %0#0, %none : (!torch.int, !torch.none) -> !torch.list<optional<int>>\n"
" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg1) : (!torch.union<float, int>) -> !torch.int\n"
" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg1) : (!torch.number) -> !torch.int\n"
" %3 = torch.prim.ListConstruct %0#1, %2 : (!torch.int, !torch.int) -> !torch.list<int>\n"
" %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%1, %3) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n"
" return %4 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.div.Scalar\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.union<float, int>) -> !torch.int {\n"
" func.func @\"__torch_mlir_dtype_fn.aten.div.Scalar\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.number) -> !torch.int {\n"
" %int6 = torch.constant.int 6\n"
" %none = torch.constant.none\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" %1 = torch.prim.ListConstruct %0#0, %none : (!torch.int, !torch.none) -> !torch.list<optional<int>>\n"
" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg1) : (!torch.union<float, int>) -> !torch.int\n"
" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg1) : (!torch.number) -> !torch.int\n"
" %3 = torch.prim.ListConstruct %0#1, %2 : (!torch.int, !torch.int) -> !torch.list<int>\n"
" %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%1, %3) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n"
" %5 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_integer_dtype(%4) : (!torch.int) -> !torch.bool\n"
@ -9468,16 +9468,16 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" }\n"
" return %6 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.fmod.Scalar\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.union<float, int>) -> !torch.int {\n"
" func.func @\"__torch_mlir_dtype_fn.aten.fmod.Scalar\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.number) -> !torch.int {\n"
" %none = torch.constant.none\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" %1 = torch.prim.ListConstruct %0#0, %none : (!torch.int, !torch.none) -> !torch.list<optional<int>>\n"
" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg1) : (!torch.union<float, int>) -> !torch.int\n"
" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg1) : (!torch.number) -> !torch.int\n"
" %3 = torch.prim.ListConstruct %0#1, %2 : (!torch.int, !torch.int) -> !torch.list<int>\n"
" %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%1, %3) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n"
" return %4 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.floor_divide.Scalar\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.union<float, int>) -> !torch.int {\n"
" func.func @\"__torch_mlir_dtype_fn.aten.floor_divide.Scalar\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.number) -> !torch.int {\n"
" %none = torch.constant.none\n"
" %str = torch.constant.str \"AssertionError: \"\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
@ -9490,21 +9490,21 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" torch.prim.If.yield\n"
" }\n"
" %3 = torch.prim.ListConstruct %0#0, %none : (!torch.int, !torch.none) -> !torch.list<optional<int>>\n"
" %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg1) : (!torch.union<float, int>) -> !torch.int\n"
" %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg1) : (!torch.number) -> !torch.int\n"
" %5 = torch.prim.ListConstruct %0#1, %4 : (!torch.int, !torch.int) -> !torch.list<int>\n"
" %6 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%3, %5) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n"
" return %6 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.pow.Tensor_Scalar\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.union<float, int>) -> !torch.int {\n"
" func.func @\"__torch_mlir_dtype_fn.aten.pow.Tensor_Scalar\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.number) -> !torch.int {\n"
" %none = torch.constant.none\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" %1 = torch.prim.ListConstruct %0#0, %none : (!torch.int, !torch.none) -> !torch.list<optional<int>>\n"
" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg1) : (!torch.union<float, int>) -> !torch.int\n"
" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg1) : (!torch.number) -> !torch.int\n"
" %3 = torch.prim.ListConstruct %0#1, %2 : (!torch.int, !torch.int) -> !torch.list<int>\n"
" %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%1, %3) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n"
" return %4 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.leaky_relu\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.union<float, int>) -> !torch.int {\n"
" func.func @\"__torch_mlir_dtype_fn.aten.leaky_relu\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.number) -> !torch.int {\n"
" %none = torch.constant.none\n"
" %str = torch.constant.str \"AssertionError: \"\n"
" %int11 = torch.constant.int 11\n"
@ -9517,7 +9517,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" torch.prim.If.yield\n"
" }\n"
" %2 = torch.prim.ListConstruct %0#0, %none : (!torch.int, !torch.none) -> !torch.list<optional<int>>\n"
" %3 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg1) : (!torch.union<float, int>) -> !torch.int\n"
" %3 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg1) : (!torch.number) -> !torch.int\n"
" %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_float_dtype(%3) : (!torch.int) -> !torch.bool\n"
" torch.prim.If %4 -> () {\n"
" %7 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_integer_dtype(%0#1) : (!torch.int) -> !torch.bool\n"
@ -9536,16 +9536,16 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %6 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%2, %5) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n"
" return %6 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.remainder.Scalar\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.union<float, int>) -> !torch.int {\n"
" func.func @\"__torch_mlir_dtype_fn.aten.remainder.Scalar\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.number) -> !torch.int {\n"
" %none = torch.constant.none\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" %1 = torch.prim.ListConstruct %0#0, %none : (!torch.int, !torch.none) -> !torch.list<optional<int>>\n"
" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg1) : (!torch.union<float, int>) -> !torch.int\n"
" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg1) : (!torch.number) -> !torch.int\n"
" %3 = torch.prim.ListConstruct %0#1, %2 : (!torch.int, !torch.int) -> !torch.list<int>\n"
" %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%1, %3) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n"
" return %4 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.baddbmm\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.tuple<int, int>, %arg3: !torch.union<float, int>, %arg4: !torch.union<float, int>) -> !torch.int {\n"
" func.func @\"__torch_mlir_dtype_fn.aten.baddbmm\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.tuple<int, int>, %arg3: !torch.number, %arg4: !torch.number) -> !torch.int {\n"
" %none = torch.constant.none\n"
" %str = torch.constant.str \"AssertionError: \"\n"
" %int5 = torch.constant.int 5\n"
@ -9590,14 +9590,14 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n"
" return %4 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.where.Scalar\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.union<float, int>, %arg2: !torch.union<float, int>) -> !torch.int {\n"
" func.func @\"__torch_mlir_dtype_fn.aten.where.Scalar\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.number, %arg2: !torch.number) -> !torch.int {\n"
" %int6 = torch.constant.int 6\n"
" %int4 = torch.constant.int 4\n"
" %false = torch.constant.bool false\n"
" %0 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg1) : (!torch.union<float, int>) -> !torch.int\n"
" %0 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg1) : (!torch.number) -> !torch.int\n"
" %1 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_integer_dtype(%0) : (!torch.int) -> !torch.bool\n"
" %2 = torch.prim.If %1 -> (!torch.bool) {\n"
" %4 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg2) : (!torch.union<float, int>) -> !torch.int\n"
" %4 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg2) : (!torch.number) -> !torch.int\n"
" %5 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_integer_dtype(%4) : (!torch.int) -> !torch.bool\n"
" torch.prim.If.yield %5 : !torch.bool\n"
" } else {\n"
@ -9610,20 +9610,20 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" }\n"
" return %3 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.where.ScalarOther\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.union<float, int>) -> !torch.int {\n"
" func.func @\"__torch_mlir_dtype_fn.aten.where.ScalarOther\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.number) -> !torch.int {\n"
" %none = torch.constant.none\n"
" %0:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" %1 = torch.prim.ListConstruct %0#0, %none : (!torch.int, !torch.none) -> !torch.list<optional<int>>\n"
" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg2) : (!torch.union<float, int>) -> !torch.int\n"
" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg2) : (!torch.number) -> !torch.int\n"
" %3 = torch.prim.ListConstruct %0#1, %2 : (!torch.int, !torch.int) -> !torch.list<int>\n"
" %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%1, %3) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n"
" return %4 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.where.ScalarSelf\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.union<float, int>, %arg2: !torch.tuple<int, int>) -> !torch.int {\n"
" func.func @\"__torch_mlir_dtype_fn.aten.where.ScalarSelf\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.number, %arg2: !torch.tuple<int, int>) -> !torch.int {\n"
" %none = torch.constant.none\n"
" %0:2 = torch.prim.TupleUnpack %arg2 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" %1 = torch.prim.ListConstruct %none, %0#0 : (!torch.none, !torch.int) -> !torch.list<optional<int>>\n"
" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg1) : (!torch.union<float, int>) -> !torch.int\n"
" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg1) : (!torch.number) -> !torch.int\n"
" %3 = torch.prim.ListConstruct %2, %0#1 : (!torch.int, !torch.int) -> !torch.list<int>\n"
" %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%1, %3) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n"
" return %4 : !torch.int\n"
@ -9701,7 +9701,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %3 = torch.prim.TupleConstruct %0#1, %0#1, %2 : !torch.int, !torch.int, !torch.int -> !torch.tuple<int, int, int>\n"
" return %3 : !torch.tuple<int, int, int>\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.arange\"(%arg0: !torch.union<float, int>, %arg1: !torch.optional<int>, %arg2: !torch.optional<int>, %arg3: !torch.optional<Device>, %arg4: !torch.optional<bool>) -> !torch.int {\n"
" func.func @\"__torch_mlir_dtype_fn.aten.arange\"(%arg0: !torch.number, %arg1: !torch.optional<int>, %arg2: !torch.optional<int>, %arg3: !torch.optional<Device>, %arg4: !torch.optional<bool>) -> !torch.int {\n"
" %int4 = torch.constant.int 4\n"
" %int6 = torch.constant.int 6\n"
" %str = torch.constant.str \"AssertionError: \"\n"
@ -9719,7 +9719,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" }\n"
" torch.prim.If.yield %2 : !torch.int\n"
" } else {\n"
" %2 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg0) : (!torch.union<float, int>) -> !torch.int\n"
" %2 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg0) : (!torch.number) -> !torch.int\n"
" %3 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_float_dtype(%2) : (!torch.int) -> !torch.bool\n"
" %4 = torch.prim.If %3 -> (!torch.int) {\n"
" torch.prim.If.yield %int6 : !torch.int\n"
@ -9730,7 +9730,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" }\n"
" return %1 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.arange.start\"(%arg0: !torch.union<float, int>, %arg1: !torch.union<float, int>, %arg2: !torch.optional<int>, %arg3: !torch.optional<int>, %arg4: !torch.optional<Device>, %arg5: !torch.optional<bool>) -> !torch.int {\n"
" func.func @\"__torch_mlir_dtype_fn.aten.arange.start\"(%arg0: !torch.number, %arg1: !torch.number, %arg2: !torch.optional<int>, %arg3: !torch.optional<int>, %arg4: !torch.optional<Device>, %arg5: !torch.optional<bool>) -> !torch.int {\n"
" %int4 = torch.constant.int 4\n"
" %int6 = torch.constant.int 6\n"
" %true = torch.constant.bool true\n"
@ -9749,12 +9749,12 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" }\n"
" torch.prim.If.yield %2 : !torch.int\n"
" } else {\n"
" %2 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg0) : (!torch.union<float, int>) -> !torch.int\n"
" %2 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg0) : (!torch.number) -> !torch.int\n"
" %3 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_float_dtype(%2) : (!torch.int) -> !torch.bool\n"
" %4 = torch.prim.If %3 -> (!torch.bool) {\n"
" torch.prim.If.yield %true : !torch.bool\n"
" } else {\n"
" %6 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg1) : (!torch.union<float, int>) -> !torch.int\n"
" %6 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg1) : (!torch.number) -> !torch.int\n"
" %7 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_float_dtype(%6) : (!torch.int) -> !torch.bool\n"
" torch.prim.If.yield %7 : !torch.bool\n"
" }\n"
@ -9767,7 +9767,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" }\n"
" return %1 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.arange.start_step\"(%arg0: !torch.union<float, int>, %arg1: !torch.union<float, int>, %arg2: !torch.union<float, int>, %arg3: !torch.optional<int>, %arg4: !torch.optional<int>, %arg5: !torch.optional<Device>, %arg6: !torch.optional<bool>) -> !torch.int {\n"
" func.func @\"__torch_mlir_dtype_fn.aten.arange.start_step\"(%arg0: !torch.number, %arg1: !torch.number, %arg2: !torch.number, %arg3: !torch.optional<int>, %arg4: !torch.optional<int>, %arg5: !torch.optional<Device>, %arg6: !torch.optional<bool>) -> !torch.int {\n"
" %int4 = torch.constant.int 4\n"
" %int6 = torch.constant.int 6\n"
" %true = torch.constant.bool true\n"
@ -9786,19 +9786,19 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" }\n"
" torch.prim.If.yield %2 : !torch.int\n"
" } else {\n"
" %2 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg0) : (!torch.union<float, int>) -> !torch.int\n"
" %2 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg0) : (!torch.number) -> !torch.int\n"
" %3 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_float_dtype(%2) : (!torch.int) -> !torch.bool\n"
" %4 = torch.prim.If %3 -> (!torch.bool) {\n"
" torch.prim.If.yield %true : !torch.bool\n"
" } else {\n"
" %7 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg1) : (!torch.union<float, int>) -> !torch.int\n"
" %7 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg1) : (!torch.number) -> !torch.int\n"
" %8 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_float_dtype(%7) : (!torch.int) -> !torch.bool\n"
" torch.prim.If.yield %8 : !torch.bool\n"
" }\n"
" %5 = torch.prim.If %4 -> (!torch.bool) {\n"
" torch.prim.If.yield %true : !torch.bool\n"
" } else {\n"
" %7 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg2) : (!torch.union<float, int>) -> !torch.int\n"
" %7 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg2) : (!torch.number) -> !torch.int\n"
" %8 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_float_dtype(%7) : (!torch.int) -> !torch.bool\n"
" torch.prim.If.yield %8 : !torch.bool\n"
" }\n"
@ -9910,7 +9910,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %0 = call @\"__torch_mlir_dtype_fn.aten.std\"(%arg0, %true) : (!torch.tuple<int, int>, !torch.bool) -> !torch.int\n"
" return %0 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.std.correction\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.optional<list<int>>, %arg2: !torch.union<float, int, none>, %arg3: !torch.bool) -> !torch.int {\n"
" func.func @\"__torch_mlir_dtype_fn.aten.std.correction\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.optional<list<int>>, %arg2: !torch.optional<number>, %arg3: !torch.bool) -> !torch.int {\n"
" %true = torch.constant.bool true\n"
" %0 = call @\"__torch_mlir_dtype_fn.aten.std\"(%arg0, %true) : (!torch.tuple<int, int>, !torch.bool) -> !torch.int\n"
" return %0 : !torch.int\n"
@ -9925,7 +9925,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %0 = call @\"__torch_mlir_dtype_fn.aten.std\"(%arg0, %true) : (!torch.tuple<int, int>, !torch.bool) -> !torch.int\n"
" return %0 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.var.correction\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.optional<list<int>>, %arg2: !torch.union<float, int, none>, %arg3: !torch.bool) -> !torch.int {\n"
" func.func @\"__torch_mlir_dtype_fn.aten.var.correction\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.optional<list<int>>, %arg2: !torch.optional<number>, %arg3: !torch.bool) -> !torch.int {\n"
" %true = torch.constant.bool true\n"
" %0 = call @\"__torch_mlir_dtype_fn.aten.std\"(%arg0, %true) : (!torch.tuple<int, int>, !torch.bool) -> !torch.int\n"
" return %0 : !torch.int\n"
@ -9935,7 +9935,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %0 = call @\"__torch_mlir_dtype_fn.aten.std\"(%arg0, %true) : (!torch.tuple<int, int>, !torch.bool) -> !torch.int\n"
" return %0 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.linalg_vector_norm\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.union<float, int>, %arg2: !torch.optional<list<int>>, %arg3: !torch.bool, %arg4: !torch.optional<int>) -> !torch.int {\n"
" func.func @\"__torch_mlir_dtype_fn.aten.linalg_vector_norm\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.number, %arg2: !torch.optional<list<int>>, %arg3: !torch.bool, %arg4: !torch.optional<int>) -> !torch.int {\n"
" %true = torch.constant.bool true\n"
" %none = torch.constant.none\n"
" %str = torch.constant.str \"AssertionError: \"\n"
@ -10061,7 +10061,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" }\n"
" return %1 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.full\"(%arg0: !torch.list<int>, %arg1: !torch.union<float, int>, %arg2: !torch.optional<int>, %arg3: !torch.optional<int>, %arg4: !torch.optional<Device>, %arg5: !torch.optional<bool>) -> !torch.int {\n"
" func.func @\"__torch_mlir_dtype_fn.aten.full\"(%arg0: !torch.list<int>, %arg1: !torch.number, %arg2: !torch.optional<int>, %arg3: !torch.optional<int>, %arg4: !torch.optional<Device>, %arg5: !torch.optional<bool>) -> !torch.int {\n"
" %int6 = torch.constant.int 6\n"
" %none = torch.constant.none\n"
" %0 = torch.aten.__isnot__ %arg2, %none : !torch.optional<int>, !torch.none -> !torch.bool\n"
@ -10069,7 +10069,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %2 = torch.prim.unchecked_cast %arg2 : !torch.optional<int> -> !torch.int\n"
" torch.prim.If.yield %2 : !torch.int\n"
" } else {\n"
" %2 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg1) : (!torch.union<float, int>) -> !torch.int\n"
" %2 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg1) : (!torch.number) -> !torch.int\n"
" %3 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_float_dtype(%2) : (!torch.int) -> !torch.bool\n"
" %4 = torch.prim.If %3 -> (!torch.int) {\n"
" torch.prim.If.yield %int6 : !torch.int\n"
@ -10116,7 +10116,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" }\n"
" return %2 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.full_like\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.union<float, int>, %arg2: !torch.optional<int>, %arg3: !torch.optional<int>, %arg4: !torch.optional<Device>, %arg5: !torch.optional<bool>, %arg6: !torch.optional<int>) -> !torch.int {\n"
" func.func @\"__torch_mlir_dtype_fn.aten.full_like\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.number, %arg2: !torch.optional<int>, %arg3: !torch.optional<int>, %arg4: !torch.optional<Device>, %arg5: !torch.optional<bool>, %arg6: !torch.optional<int>) -> !torch.int {\n"
" %none = torch.constant.none\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" %1 = torch.aten.__is__ %arg2, %none : !torch.optional<int>, !torch.none -> !torch.bool\n"
@ -10313,7 +10313,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" }\n"
" return %1 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.var_mean.correction\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.optional<list<int>>, %arg2: !torch.union<float, int, none>, %arg3: !torch.bool) -> !torch.tuple<int, int> {\n"
" func.func @\"__torch_mlir_dtype_fn.aten.var_mean.correction\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.optional<list<int>>, %arg2: !torch.optional<number>, %arg3: !torch.bool) -> !torch.tuple<int, int> {\n"
" %int7 = torch.constant.int 7\n"
" %int10 = torch.constant.int 10\n"
" %int6 = torch.constant.int 6\n"
@ -10485,8 +10485,8 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" }\n"
" return %5 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.prim.NumToTensor.Scalar\"(%arg0: !torch.union<float, int>) -> !torch.int {\n"
" %0 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg0) : (!torch.union<float, int>) -> !torch.int\n"
" func.func @\"__torch_mlir_dtype_fn.prim.NumToTensor.Scalar\"(%arg0: !torch.number) -> !torch.int {\n"
" %0 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg0) : (!torch.number) -> !torch.int\n"
" return %0 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.softmax.int\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.int, %arg2: !torch.optional<int>) -> !torch.int {\n"


@ -176,10 +176,17 @@ FailureOr<Value> Torch::adjustFunctionArg(
return b.create<DerefineOp>(loc, desiredType, operand).getResult();
}
// !torch.union<int, float> or !torch.union<int, float, none> is the type used
// for (optional) `Scalar` inputs. At compile time, such inputs will usually
// be resolved to an `int` or a `float` so we need to derefine to match the
// library function signature.
// The type `!torch.number` can be an `int`, `float`, or `complex`.
// TODO: Add a new type `Torch::ComplexType` to handle the complex case.
if (desiredType.isa<Torch::NumberType>() &&
operandType.isa<Torch::IntType, Torch::FloatType>()) {
return b.create<DerefineOp>(loc, desiredType, operand).getResult();
}
// !torch.union<int, float, none> is the type used for optional
// `Scalar` inputs. At compile time, such inputs will usually be
// resolved to an `int`, `float`, or `None` so we need to derefine
// to match the library function signature.
if (auto unionType = desiredType.dyn_cast<Torch::UnionType>()) {
if (llvm::all_of(unionType.getContainedTypes(), [](Type containedType) {
return containedType


@ -772,7 +772,7 @@ def atenscalar_tensor〡shape(s: float, dtype: Optional[int] = None, layout:
return []
@check_dtype_function([Invocation(-1), Invocation(-1.0)])
def atenscalar_tensor〡dtype(s: Union[int, float], dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None) -> int:
def atenscalar_tensor〡dtype(s: Union[int, float, complex], dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None) -> int:
if dtype is not None:
return dtype
else:
@ -1314,7 +1314,7 @@ def atenerf〡dtype(self_rank_dtype: Tuple[int, int]) -> int:
return _get_dtype_of_floating_point_op(self_dtype)
@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1))
def atensoftplus〡dtype(self_rank_dtype: Tuple[int, int], beta: Union[int, float] = 1, threshold: Union[int, float] = 20) -> int:
def atensoftplus〡dtype(self_rank_dtype: Tuple[int, int], beta: Union[int, float, complex] = 1, threshold: Union[int, float, complex] = 20) -> int:
self_rank, self_dtype = self_rank_dtype
if is_integer_dtype(self_dtype):
return self_dtype
@ -1395,21 +1395,21 @@ def atenceil〡dtype(self_rank_dtype: Tuple[int, int]) -> int:
return self_dtype
@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, max=0))
def atenclamp_max〡dtype(self_rank_dtype: Tuple[int, int], max: Union[int, float]) -> int:
def atenclamp_max〡dtype(self_rank_dtype: Tuple[int, int], max: Union[int, float, complex]) -> int:
self_rank, self_dtype = self_rank_dtype
if self_dtype == torch.bool:
return torch.int64
return self_dtype
@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, min=0))
def atenclamp_min〡dtype(self_rank_dtype: Tuple[int, int], min: Union[int, float]) -> int:
def atenclamp_min〡dtype(self_rank_dtype: Tuple[int, int], min: Union[int, float, complex]) -> int:
self_rank, self_dtype = self_rank_dtype
if self_dtype == torch.bool:
return torch.int64
return self_dtype
@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, min=-1, max=1))
def atenclamp〡dtype(self_rank_dtype: Tuple[int, int], min: Optional[Union[int, float]] = None, max: Optional[Union[int, float]] = None) -> int:
def atenclamp〡dtype(self_rank_dtype: Tuple[int, int], min: Optional[Union[int, float, complex]] = None, max: Optional[Union[int, float, complex]] = None) -> int:
self_rank, self_dtype = self_rank_dtype
if self_dtype == torch.bool:
return torch.int64
@ -1421,7 +1421,7 @@ def atenclone〡dtype(self_rank_dtype: Tuple[int, int], memory_format: Option
return self_dtype
@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, pad=[1, 1]))
def atenconstant_pad_nd〡dtype(self_rank_dtype: Tuple[int, int], pad: List[int], value: Union[int, float] = 0) -> int:
def atenconstant_pad_nd〡dtype(self_rank_dtype: Tuple[int, int], pad: List[int], value: Union[int, float, complex] = 0) -> int:
self_rank, self_dtype = self_rank_dtype
return self_dtype
@ -1478,7 +1478,7 @@ def atenexpand〡dtype(self_rank_dtype: Tuple[int, int], size: List[int], imp
return self_dtype
@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, value=0))
def atenfillScalar〡dtype(self_rank_dtype: Tuple[int, int], value: Union[int, float]) -> int:
def atenfillScalar〡dtype(self_rank_dtype: Tuple[int, int], value: Union[int, float, complex]) -> int:
self_rank, self_dtype = self_rank_dtype
return self_dtype
@ -1537,14 +1537,14 @@ def atenhardswish〡dtype(self_rank_dtype: Tuple[int, int]) -> int:
return self_dtype
@check_dtype_function(_check_two_tensor_op(min_val=0.2, max_val=0.5))
def atenhardtanh_backward〡dtype(grad_output_rank_dtype: Tuple[int, int], self_rank_dtype: Tuple[int, int], min_val: Union[int, float], max_val: Union[int, float]) -> int:
def atenhardtanh_backward〡dtype(grad_output_rank_dtype: Tuple[int, int], self_rank_dtype: Tuple[int, int], min_val: Union[int, float, complex], max_val: Union[int, float, complex]) -> int:
grad_output_rank, grad_output_dtype = grad_output_rank_dtype
if is_integer_dtype(grad_output_dtype):
return torch.float32
return grad_output_dtype
@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, error_types={torch.uint8, torch.bool}))
def atenhardtanh〡dtype(self_rank_dtype: Tuple[int, int], min_val: Union[int, float] = -1, max_val: Union[int, float] = 1) -> int:
def atenhardtanh〡dtype(self_rank_dtype: Tuple[int, int], min_val: Union[int, float, complex] = -1, max_val: Union[int, float, complex] = 1) -> int:
self_rank, self_dtype = self_rank_dtype
assert self_dtype not in [torch.uint8, torch.bool]
return self_dtype
@ -1597,7 +1597,7 @@ def aten〇layer_norm〡dtype(input_rank_dtype: Tuple[int, int], normalized_shap
return input_dtype
@check_dtype_function(_check_two_tensor_op(negative_slope=0.1, self_is_result=False))
def aten〇leaky_relu_backward〡dtype(grad_output_rank_dtype: Tuple[int, int], self_rank_dtype: Tuple[int, int], negative_slope: Union[int, float], self_is_result: bool) -> int:
def aten〇leaky_relu_backward〡dtype(grad_output_rank_dtype: Tuple[int, int], self_rank_dtype: Tuple[int, int], negative_slope: Union[int, float, complex], self_is_result: bool) -> int:
grad_output_rank, grad_output_dtype = grad_output_rank_dtype
self_rank, self_dtype = self_rank_dtype
ranks: List[Optional[int]] = [grad_output_rank, self_rank]
@ -1617,12 +1617,12 @@ def aten〇_log_softmax_backward_data〡dtype(grad_output_rank_dtype: Tuple[int,
return input_dtype
@check_dtype_function(_check_tensors_with_the_same_dtype(None, [(3,)], None, None, TensorOfShape(1, dtype=torch.bool), 0))
def aten〇masked_fill〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], mask_rank_dtype: Tuple[int, int], value: Union[int, float]) -> int:
def aten〇masked_fill〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], mask_rank_dtype: Tuple[int, int], value: Union[int, float, complex]) -> int:
self_rank, self_dtype = self_rank_dtype
return self_dtype
@check_dtype_function(_check_tensors_with_the_same_dtype(None, [(3,)], None, None, TensorOfShape(1, dtype=torch.bool), 0))
def aten〇masked_fill_〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], mask_rank_dtype: Tuple[int, int], value: Union[int, float]) -> int:
def aten〇masked_fill_〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], mask_rank_dtype: Tuple[int, int], value: Union[int, float, complex]) -> int:
self_rank, self_dtype = self_rank_dtype
return self_dtype
@ -1766,7 +1766,7 @@ def aten〇scatter〇src〡dtype(self_rank_dtype: Tuple[int, int], dim: int, ind
@check_dtype_function(
[Invocation(TensorOfShape(3, dtype=dtype), 0, TensorOfShape(3, dtype=torch.int64), 1.0) for dtype in _SORTED_TORCH_TYPES])
def aten〇scatter〇value〡dtype(self_rank_dtype: Tuple[int, int], dim: int, index_rank_dtype: Tuple[int, int], value: Union[int, float]) -> int:
def aten〇scatter〇value〡dtype(self_rank_dtype: Tuple[int, int], dim: int, index_rank_dtype: Tuple[int, int], value: Union[int, float, complex]) -> int:
self_rank, self_dtype = self_rank_dtype
return self_dtype
@ -1820,7 +1820,7 @@ def aten〇tanh_backward〡dtype(grad_output_rank_dtype: Tuple[int, int], output
return promoted_dtype
@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, threshold=0, value=0))
def aten〇threshold〡dtype(self_rank_dtype: Tuple[int, int], threshold: Union[int, float], value: Union[int, float]) -> int:
def aten〇threshold〡dtype(self_rank_dtype: Tuple[int, int], threshold: Union[int, float, complex], value: Union[int, float, complex]) -> int:
self_rank, self_dtype = self_rank_dtype
return self_dtype
@ -1890,7 +1890,7 @@ def aten〇zero_〡dtype(self_rank_dtype: Tuple[int, int]) -> int:
return self_dtype
@check_dtype_function([Invocation(-1), Invocation(-1.0)])
def prim〇abs〇Scalar〡dtype(a: Union[int, float]) -> int:
def prim〇abs〇Scalar〡dtype(a: Union[int, float, complex]) -> int:
return get_dtype_of_scalar(a)
@check_dtype_function(_check_tensors_with_the_same_dtype(
@ -1931,7 +1931,7 @@ def aten〇any〡dtype(self_rank_dtype: Tuple[int, int]) -> int:
@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, other=0.0) +
_check_tensors_with_the_same_dtype(num_of_tensors=1, other=0))
def aten〇eq〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], other: Union[int, float]) -> int:
def aten〇eq〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], other: Union[int, float, complex]) -> int:
return torch.bool
@check_dtype_function(_check_two_tensor_op())
@ -1941,13 +1941,13 @@ def aten〇eq〇Tensor〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtyp
@check_dtype_function(
_check_tensors_with_the_same_dtype(num_of_tensors=1, other=0.0) +
_check_tensors_with_the_same_dtype(num_of_tensors=1, other=0))
def aten〇ge〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], other: Union[int, float]) -> int:
def aten〇ge〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], other: Union[int, float, complex]) -> int:
return torch.bool
@check_dtype_function(
_check_tensors_with_the_same_dtype(num_of_tensors=1, other=0.0) +
_check_tensors_with_the_same_dtype(num_of_tensors=1, other=0))
def aten〇gt〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], other: Union[int, float]) -> int:
def aten〇gt〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], other: Union[int, float, complex]) -> int:
return torch.bool
@check_dtype_function(_check_two_tensor_op())
@ -1961,7 +1961,7 @@ def aten〇ge〇Tensor〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtyp
@check_dtype_function(
_check_tensors_with_the_same_dtype(num_of_tensors=1, other=0.0) +
_check_tensors_with_the_same_dtype(num_of_tensors=1, other=0))
def aten〇le〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], other: Union[int, float]) -> int:
def aten〇le〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], other: Union[int, float, complex]) -> int:
return torch.bool
@check_dtype_function(_check_two_tensor_op())
@ -1988,7 +1988,7 @@ def aten〇logical_xor〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtyp
@check_dtype_function(
_check_tensors_with_the_same_dtype(num_of_tensors=1, other=0.0) +
_check_tensors_with_the_same_dtype(num_of_tensors=1, other=0))
def aten〇lt〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], other: Union[int, float]) -> int:
def aten〇lt〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], other: Union[int, float, complex]) -> int:
return torch.bool
@check_dtype_function(_check_two_tensor_op())
@ -2010,7 +2010,7 @@ def aten〇ne〇Tensor〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtyp
@check_dtype_function(
_check_tensors_with_the_same_dtype(num_of_tensors=1, other=0.0) +
_check_tensors_with_the_same_dtype(num_of_tensors=1, other=0))
def aten〇ne〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], other: Union[int, float]) -> int:
def aten〇ne〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], other: Union[int, float, complex]) -> int:
return torch.bool
@check_dtype_function([
@ -2019,7 +2019,7 @@ def aten〇ne〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], other: Union[in
Invocation(0, 0.0), # int, float
Invocation(0, 0), # int, int
])
def aten〇add〡dtype(a: Union[int, float], b: Union[int, float]) -> int:
def aten〇add〡dtype(a: Union[int, float, complex], b: Union[int, float, complex]) -> int:
ranks: List[Optional[int]] = [None, None]
dtypes = [get_dtype_of_scalar(a), get_dtype_of_scalar(b)]
return promote_dtypes(ranks, dtypes)
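For scalar-scalar ops like `aten〇add`, the widened annotation means the promotion logic now has to accept complex operands as well. A quick illustration of the expected results, assuming `get_dtype_of_scalar` follows the `prim.NumToTensor` mapping (int to `torch.int64`, float to `torch.float64`) and `promote_dtypes` mirrors PyTorch's usual promotion rules; the import path below is illustrative only:
import torch
# Hypothetical import path; the function is part of the generated dtype library.
from abstract_interp_lib_gen import aten〇add〡dtype
# (int, int) stays integral; anything involving a float promotes to float64,
# matching the scalar-to-dtype mapping used by prim.NumToTensor.
assert aten〇add〡dtype(0, 0) == torch.int64
assert aten〇add〡dtype(0, 0.0) == torch.float64
assert aten〇add〡dtype(0.0, 0.0) == torch.float64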
@ -2044,7 +2044,7 @@ def aten〇fft_fft〡dtype(self_rank_dtype: Tuple[int, int], n: Optional[int] =
@check_dtype_function(
_check_tensors_with_the_same_dtype(num_of_tensors=1, other=0.0) +
_check_tensors_with_the_same_dtype(num_of_tensors=1, other=0))
def aten〇rsub〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], other: Union[int, float], alpha: Union[int, float] = 1) -> int:
def aten〇rsub〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], other: Union[int, float, complex], alpha: Union[int, float, complex] = 1) -> int:
self_rank, self_dtype = self_rank_dtype
return promote_dtypes([self_rank, None], [self_dtype, get_dtype_of_scalar(other)])
@ -2057,7 +2057,7 @@ def aten〇__and__〇Tensor〡dtype(self_rank_dtype: Tuple[int, int], other_rank
return promote_dtypes(ranks, dtypes)
@check_dtype_function(_check_two_tensor_op())
def aten〇add〇Tensor〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int], alpha: Union[int, float] = 1) -> int:
def aten〇add〇Tensor〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int], alpha: Union[int, float, complex] = 1) -> int:
other_rank, other_dtype = other_rank_dtype
self_rank, self_dtype = self_rank_dtype
ranks: List[Optional[int]] = [self_rank, other_rank]
@ -2238,7 +2238,7 @@ def aten〇mv〡dtype(self_rank_dtype: Tuple[int, int], vec_rank_dtype: Tuple[in
return promote_dtypes(ranks, dtypes)
@check_dtype_function(_check_two_tensor_op())
def aten〇sub〇Tensor〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int], alpha: Union[int, float] = 1) -> int:
def aten〇sub〇Tensor〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int], alpha: Union[int, float, complex] = 1) -> int:
other_rank, other_dtype = other_rank_dtype
self_rank, self_dtype = self_rank_dtype
ranks: List[Optional[int]] = [self_rank, other_rank]
@ -2249,7 +2249,7 @@ def aten〇sub〇Tensor〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dty
# https://github.com/pytorch/pytorch/issues/100921
# TODO: This should be fixed by switching to FakeTensor instead of Meta tensor
@check_dtype_function(_check_two_tensor_op(tensor_device="cpu", input_error_types={torch.complex64, torch.complex128}, output_error_types={torch.bool}, threshold=0))
def aten〇threshold_backward〡dtype(grad_output_rank_dtype: Tuple[int, int], self_rank_dtype: Tuple[int, int], threshold: Union[int, float]) -> int:
def aten〇threshold_backward〡dtype(grad_output_rank_dtype: Tuple[int, int], self_rank_dtype: Tuple[int, int], threshold: Union[int, float, complex]) -> int:
self_rank, self_dtype = self_rank_dtype
grad_output_rank, grad_output_dtype = grad_output_rank_dtype
assert not is_complex_dtype(grad_output_dtype), "`grad_output` cannot be complex"
@ -2433,7 +2433,7 @@ def aten〇bincount〡dtype(self_rank_dtype: Tuple[int, int], weights_rank_dtype
Invocation(TensorOfShape(3, 3, dtype=torch.int32),
TensorOfShape(3, 4, dtype=torch.float32),
TensorOfShape(4, 3, dtype=torch.float32))])
def aten〇addmm〡dtype(self_rank_dtype: Tuple[int, int], mat1_rank_dtype: Tuple[int, int], mat2_rank_dtype: Tuple[int, int], beta: Union[int, float] = 1, alpha: Union[int, float] = 1) -> int:
def aten〇addmm〡dtype(self_rank_dtype: Tuple[int, int], mat1_rank_dtype: Tuple[int, int], mat2_rank_dtype: Tuple[int, int], beta: Union[int, float, complex] = 1, alpha: Union[int, float, complex] = 1) -> int:
self_rank, self_dtype = self_rank_dtype
mat1_rank, mat1_dtype = mat1_rank_dtype
mat2_rank, mat2_dtype = mat2_rank_dtype
@ -2477,7 +2477,7 @@ def aten〇lerp〇Tensor〡dtype(self_rank_dtype: Tuple[int, int], end_rank_dtyp
Invocation(TensorOfShape(3, 3, dtype=torch.int32),
TensorOfShape(3, 3, dtype=torch.float32),
TensorOfShape(3, 3, dtype=torch.float32))])
def aten〇addcmul〡dtype(self_rank_dtype: Tuple[int, int], tensor1_rank_dtype: Tuple[int, int], tensor2_rank_dtype: Tuple[int, int], value: Union[int, float] = 1) -> int:
def aten〇addcmul〡dtype(self_rank_dtype: Tuple[int, int], tensor1_rank_dtype: Tuple[int, int], tensor2_rank_dtype: Tuple[int, int], value: Union[int, float, complex] = 1) -> int:
self_rank, self_dtype = self_rank_dtype
tensor1_rank, tensor1_dtype = tensor1_rank_dtype
tensor2_rank, tensor2_dtype = tensor2_rank_dtype
@ -2503,7 +2503,7 @@ def aten〇addcmul〡dtype(self_rank_dtype: Tuple[int, int], tensor1_rank_dtype:
Invocation(TensorOfShape(3, 3, dtype=torch.int32),
TensorOfShape(3, 3, dtype=torch.float32),
TensorOfShape(3, 3, dtype=torch.float32))])
def aten〇addcdiv〡dtype(self_rank_dtype: Tuple[int, int], tensor1_rank_dtype: Tuple[int, int], tensor2_rank_dtype: Tuple[int, int], value: Union[int, float] = 1) -> int:
def aten〇addcdiv〡dtype(self_rank_dtype: Tuple[int, int], tensor1_rank_dtype: Tuple[int, int], tensor2_rank_dtype: Tuple[int, int], value: Union[int, float, complex] = 1) -> int:
self_rank, self_dtype = self_rank_dtype
tensor1_rank, tensor1_dtype = tensor1_rank_dtype
tensor2_rank, tensor2_dtype = tensor2_rank_dtype
@ -2517,7 +2517,7 @@ def aten〇addcdiv〡dtype(self_rank_dtype: Tuple[int, int], tensor1_rank_dtype:
@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, other=1) +
_check_tensors_with_the_same_dtype(num_of_tensors=1, other=1.0))
def aten〇add〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], other: Union[int, float], alpha: Union[int, float] = 1) -> int:
def aten〇add〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], other: Union[int, float, complex], alpha: Union[int, float, complex] = 1) -> int:
self_rank, self_dtype = self_rank_dtype
ranks: List[Optional[int]] = [self_rank, None]
dtypes = [self_dtype, get_dtype_of_scalar(other)]
@ -2526,7 +2526,7 @@ def aten〇add〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], other: Union[i
@check_dtype_function(
_check_tensors_with_the_same_dtype(num_of_tensors=1, other=1) +
_check_tensors_with_the_same_dtype(num_of_tensors=1, other=1.0))
def aten〇sub〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], other: Union[int, float], alpha: Union[int, float] = 1) -> int:
def aten〇sub〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], other: Union[int, float, complex], alpha: Union[int, float, complex] = 1) -> int:
self_rank, self_dtype = self_rank_dtype
ranks: List[Optional[int]] = [self_rank, None]
dtypes = [self_dtype, get_dtype_of_scalar(other)]
@ -2534,7 +2534,7 @@ def aten〇sub〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], other: Union[i
@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, other=1) +
_check_tensors_with_the_same_dtype(num_of_tensors=1, other=1.0))
def aten〇mul〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], other: Union[int, float]) -> int:
def aten〇mul〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], other: Union[int, float, complex]) -> int:
self_rank, self_dtype = self_rank_dtype
ranks: List[Optional[int]] = [self_rank, None]
dtypes = [self_dtype, get_dtype_of_scalar(other)]
@ -2542,7 +2542,7 @@ def aten〇mul〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], other: Union[i
@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, other=1) +
_check_tensors_with_the_same_dtype(num_of_tensors=1, other=1.0))
def aten〇div〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], other: Union[int, float]) -> int:
def aten〇div〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], other: Union[int, float, complex]) -> int:
self_rank, self_dtype = self_rank_dtype
ranks: List[Optional[int]] = [self_rank, None]
dtypes = [self_dtype, get_dtype_of_scalar(other)]
@ -2554,7 +2554,7 @@ def aten〇div〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], other: Union[i
@check_dtype_function(
_check_tensors_with_the_same_dtype(num_of_tensors=1, other=1) +
_check_tensors_with_the_same_dtype(num_of_tensors=1, other=1.0))
def aten〇fmod〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], other: Union[int, float]) -> int:
def aten〇fmod〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], other: Union[int, float, complex]) -> int:
self_rank, self_dtype = self_rank_dtype
ranks: List[Optional[int]] = [self_rank, None]
dtypes = [self_dtype, get_dtype_of_scalar(other)]
@ -2563,7 +2563,7 @@ def aten〇fmod〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], other: Union[
@check_dtype_function(
_check_tensors_with_the_same_dtype(num_of_tensors=1, error_types={torch.complex64, torch.complex128}, other=1) +
_check_tensors_with_the_same_dtype(num_of_tensors=1, error_types={torch.complex64, torch.complex128}, other=1.0))
def aten〇floor_divide〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], other: Union[int, float]) -> int:
def aten〇floor_divide〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], other: Union[int, float, complex]) -> int:
self_rank, self_dtype = self_rank_dtype
assert not is_complex_dtype(self_dtype)
ranks: List[Optional[int]] = [self_rank, None]
@ -2572,7 +2572,7 @@ def aten〇floor_divide〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], other
@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, exponent=1) +
_check_tensors_with_the_same_dtype(num_of_tensors=1, exponent=1.0))
def aten〇pow〇Tensor_Scalar〡dtype(self_rank_dtype: Tuple[int, int], exponent: Union[int, float]) -> int:
def aten〇pow〇Tensor_Scalar〡dtype(self_rank_dtype: Tuple[int, int], exponent: Union[int, float, complex]) -> int:
self_rank, self_dtype = self_rank_dtype
ranks: List[Optional[int]] = [self_rank, None]
dtypes = [self_dtype, get_dtype_of_scalar(exponent)]
@ -2581,7 +2581,7 @@ def aten〇pow〇Tensor_Scalar〡dtype(self_rank_dtype: Tuple[int, int], exponen
@check_dtype_function(
_check_tensors_with_the_same_dtype(num_of_tensors=1, error_types={torch.bool}, negative_slope=1) +
_check_tensors_with_the_same_dtype(num_of_tensors=1, error_types={torch.bool, torch.int8, torch.uint8, torch.int16, torch.int32, torch.int64}, negative_slope=1.0))
def aten〇leaky_relu〡dtype(self_rank_dtype: Tuple[int, int], negative_slope: Union[int, float] = 0.01) -> int:
def aten〇leaky_relu〡dtype(self_rank_dtype: Tuple[int, int], negative_slope: Union[int, float, complex] = 0.01) -> int:
self_rank, self_dtype = self_rank_dtype
assert self_dtype != torch.bool
ranks: List[Optional[int]] = [self_rank, None]
@ -2594,7 +2594,7 @@ def aten〇leaky_relu〡dtype(self_rank_dtype: Tuple[int, int], negative_slope:
@check_dtype_function(
_check_tensors_with_the_same_dtype(num_of_tensors=1, other=1) +
_check_tensors_with_the_same_dtype(num_of_tensors=1, other=1.0))
def aten〇remainder〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], other: Union[int, float]) -> int:
def aten〇remainder〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], other: Union[int, float, complex]) -> int:
self_rank, self_dtype = self_rank_dtype
ranks: List[Optional[int]] = [self_rank, None]
dtypes = [self_dtype, get_dtype_of_scalar(other)]
@ -2611,7 +2611,7 @@ def aten〇remainder〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], other: U
TensorOfShape(1, 1, 1, dtype=torch.float64, device="cpu"), TensorOfShape(1, 1, 1, dtype=torch.float16, device="cpu"), TensorOfShape(1, 1, 1, dtype=torch.int64, device="cpu")),
ErrorInvocation(
TensorOfShape(1, 1, 1, dtype=torch.float64, device="cpu"), TensorOfShape(1, 1, 1, dtype=torch.bfloat16, device="cpu"), TensorOfShape(1, 1, 1, dtype=torch.float16, device="cpu"))])
def aten〇baddbmm〡dtype(self_rank_dtype: Tuple[int, int], batch1_rank_dtype: Tuple[int, int], batch2_rank_dtype: Tuple[int, int], beta: Union[int, float] = 1, alpha: Union[int, float] = 1) -> int:
def aten〇baddbmm〡dtype(self_rank_dtype: Tuple[int, int], batch1_rank_dtype: Tuple[int, int], batch2_rank_dtype: Tuple[int, int], beta: Union[int, float, complex] = 1, alpha: Union[int, float, complex] = 1) -> int:
batch1_rank, batch1_dtype = batch1_rank_dtype
batch2_rank, batch2_dtype = batch2_rank_dtype
assert batch1_dtype not in [torch.bool, torch.float16]
@ -2637,7 +2637,7 @@ def aten〇where〇self〡dtype(condition_rank_dtype: Tuple[int, int], self_rank
Invocation(NonZeroDTensorWithDtype(torch.bool), 0, 0.0),
Invocation(NonZeroDTensorWithDtype(torch.bool), 0.0, 0),
Invocation(NonZeroDTensorWithDtype(torch.bool), 0.0, 0.0)])
def aten〇where〇Scalar〡dtype(condition_rank_dtype: Tuple[int, int], self: Union[int, float], other: Union[int, float]) -> int:
def aten〇where〇Scalar〡dtype(condition_rank_dtype: Tuple[int, int], self: Union[int, float, complex], other: Union[int, float, complex]) -> int:
if is_integer_dtype(get_dtype_of_scalar(self)) and is_integer_dtype(get_dtype_of_scalar(other)):
return torch.int64
return torch.float32
@ -2646,7 +2646,7 @@ def aten〇where〇Scalar〡dtype(condition_rank_dtype: Tuple[int, int], self: U
Invocation(NonZeroDTensorWithDtype(torch.bool), NonZeroDTensorWithDtype(torch.int64), 0.0),
Invocation(NonZeroDTensorWithDtype(torch.bool), NonZeroDTensorWithDtype(torch.float16), 0),
Invocation(NonZeroDTensorWithDtype(torch.bool), NonZeroDTensorWithDtype(torch.float64), 0.0)])
def aten〇where〇ScalarOther〡dtype(condition_rank_dtype: Tuple[int, int], self_rank_dtype: Tuple[int, int], other: Union[int, float]) -> int:
def aten〇where〇ScalarOther〡dtype(condition_rank_dtype: Tuple[int, int], self_rank_dtype: Tuple[int, int], other: Union[int, float, complex]) -> int:
self_rank, self_dtype = self_rank_dtype
ranks: List[Optional[int]] = [self_rank, None]
dtypes = [self_dtype, get_dtype_of_scalar(other)]
@ -2656,7 +2656,7 @@ def aten〇where〇ScalarOther〡dtype(condition_rank_dtype: Tuple[int, int], se
Invocation(NonZeroDTensorWithDtype(torch.bool), 0.0, NonZeroDTensorWithDtype(torch.int64)),
Invocation(NonZeroDTensorWithDtype(torch.bool), 0, NonZeroDTensorWithDtype(torch.float16)),
Invocation(NonZeroDTensorWithDtype(torch.bool), 0.0, NonZeroDTensorWithDtype(torch.float64))])
def aten〇where〇ScalarSelf〡dtype(condition_rank_dtype: Tuple[int, int], self: Union[int, float], other_rank_dtype: Tuple[int, int]) -> int:
def aten〇where〇ScalarSelf〡dtype(condition_rank_dtype: Tuple[int, int], self: Union[int, float, complex], other_rank_dtype: Tuple[int, int]) -> int:
other_rank, other_dtype = other_rank_dtype
ranks: List[Optional[int]] = [None, other_rank]
dtypes = [get_dtype_of_scalar(self), other_dtype]
@ -2755,7 +2755,7 @@ def aten〇native_batch_norm〡dtype(input_rank_dtype: Tuple[int, int], weight_r
ErrorInvocation(end=0, dtype=torch.complex64), # Dtype specified
Invocation(end=0, dtype=torch.float16), # Dtype specified
Invocation(end=0, dtype=torch.int16)]) # Dtype specified
def aten〇arange〡dtype(end: Union[int, float], dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None) -> int:
def aten〇arange〡dtype(end: Union[int, float, complex], dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None) -> int:
if dtype is not None:
assert not is_complex_dtype(dtype)
return dtype
@ -2769,7 +2769,7 @@ def aten〇arange〡dtype(end: Union[int, float], dtype: Optional[int] = None, l
ErrorInvocation(start=0, end=10, dtype=torch.complex64), # Dtype specified
Invocation(start=0, end=10, dtype=torch.float16), # Dtype specified
Invocation(start=0, end=10, dtype=torch.int16)]) # Dtype specified
def aten〇arange〇start〡dtype(start: Union[int, float], end: Union[int, float], dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None) -> int:
def aten〇arange〇start〡dtype(start: Union[int, float, complex], end: Union[int, float, complex], dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None) -> int:
if dtype is not None:
assert not is_complex_dtype(dtype)
return dtype
@ -2785,7 +2785,7 @@ def aten〇arange〇start〡dtype(start: Union[int, float], end: Union[int, floa
ErrorInvocation(start=0, end=10, step=1, dtype=torch.complex64), # Dtype specified
Invocation(start=0, end=10, step=1, dtype=torch.float16), # Dtype specified
Invocation(start=0, end=10, step=1, dtype=torch.int16)]) # Dtype specified
def aten〇arange〇start_step〡dtype(start: Union[int, float], end: Union[int, float], step: Union[int, float] = 1, dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None) -> int:
def aten〇arange〇start_step〡dtype(start: Union[int, float, complex], end: Union[int, float, complex], step: Union[int, float, complex] = 1, dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None) -> int:
if dtype is not None:
assert not is_complex_dtype(dtype)
return dtype
@ -2876,7 +2876,7 @@ def aten〇std〇dim〡dtype(self_rank_dtype: Tuple[int, int], dim: Optional[Lis
return aten〇std〡dtype(self_rank_dtype)
@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1))
def aten〇std〇correction〡dtype(self_rank_dtype: Tuple[int, int], dim: Optional[List[int]] = None, correction: Optional[Union[int, float]] = None, keepdim: bool = False) -> int:
def aten〇std〇correction〡dtype(self_rank_dtype: Tuple[int, int], dim: Optional[List[int]] = None, correction: Optional[Union[int, float, complex]] = None, keepdim: bool = False) -> int:
return aten〇std〡dtype(self_rank_dtype)
@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1))
@ -2888,7 +2888,7 @@ def aten〇var〇dim〡dtype(self_rank_dtype: Tuple[int, int], dim: Optional[Lis
return aten〇std〡dtype(self_rank_dtype)
@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1))
def aten〇var〇correction〡dtype(self_rank_dtype: Tuple[int, int], dim: Optional[List[int]] = None, correction: Optional[Union[int, float]] = None, keepdim: bool = False) -> int:
def aten〇var〇correction〡dtype(self_rank_dtype: Tuple[int, int], dim: Optional[List[int]] = None, correction: Optional[Union[int, float, complex]] = None, keepdim: bool = False) -> int:
return aten〇std〡dtype(self_rank_dtype)
@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, dims=[], correction=0.0))
@ -2906,7 +2906,7 @@ def prims〇var〡dtype(inp_rank_dtype: Tuple[int, int], dims: Optional[List[int
num_of_tensors=1,
error_types={torch.bool, torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64, torch.bfloat16, torch.float16, torch.float32, torch.float64}, dtype=torch.complex128) +
[ErrorInvocation(NonZeroDTensorWithDtype(torch.float32), dtype=torch.int32)])
def aten〇linalg_vector_norm〡dtype(self_rank_dtype: Tuple[int, int], ord: Union[int, float] = 2, dim: Optional[List[int]] = None, keepdim: bool = False, dtype: Optional[int] = None) -> int:
def aten〇linalg_vector_norm〡dtype(self_rank_dtype: Tuple[int, int], ord: Union[int, float, complex] = 2, dim: Optional[List[int]] = None, keepdim: bool = False, dtype: Optional[int] = None) -> int:
self_rank, self_dtype = self_rank_dtype
assert not is_integer_dtype(self_dtype)
if dtype is not None:
@ -2971,7 +2971,7 @@ def aten〇empty〇memory_format〡dtype(size: List[int], dtype: Optional[int] =
Invocation([1], 0.0, dtype=torch.int32),
Invocation([1], 0.0, dtype=torch.float16),
Invocation([1], 0.0, dtype=torch.complex64)])
def aten〇full〡dtype(size: List[int], fill_value: Union[int, float], dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None) -> int:
def aten〇full〡dtype(size: List[int], fill_value: Union[int, float, complex], dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None) -> int:
if dtype is not None:
return dtype
fill_value_dtype = get_dtype_of_scalar(fill_value)
@ -3009,7 +3009,7 @@ def aten〇empty_like〡dtype(self_rank_dtype: Tuple[int, int], dtype: Optional[
_check_tensors_with_the_same_dtype(num_of_tensors=1, fill_value=0.0, dtype=torch.float16) +
_check_tensors_with_the_same_dtype(num_of_tensors=1, fill_value=0.0, dtype=torch.int32) +
_check_tensors_with_the_same_dtype(num_of_tensors=1, fill_value=0.0, dtype=torch.complex64))
def aten〇full_like〡dtype(self_rank_dtype: Tuple[int, int], fill_value: Union[int, float], dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None, memory_format: Optional[int] = None) -> int:
def aten〇full_like〡dtype(self_rank_dtype: Tuple[int, int], fill_value: Union[int, float, complex], dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None, memory_format: Optional[int] = None) -> int:
self_rank, self_dtype = self_rank_dtype
return self_dtype if dtype is None else dtype
@ -3143,7 +3143,7 @@ def aten〇randn〇generator〡dtype(size: List[int], generator: Any, dtype: Opt
return dtype
@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, error_types=all_integer_dtypes()))
def aten〇var_mean〇correction〡dtype(self_rank_dtype: Tuple[int, int], dim: Optional[List[int]] = None, correction: Optional[Union[int, float]] = None, keepdim: bool = False) -> Tuple[int, int]:
def aten〇var_mean〇correction〡dtype(self_rank_dtype: Tuple[int, int], dim: Optional[List[int]] = None, correction: Optional[Union[int, float, complex]] = None, keepdim: bool = False) -> Tuple[int, int]:
self_rank, self_dtype = self_rank_dtype
assert not is_integer_dtype(self_dtype)
if self_dtype == torch.complex64:
@ -3220,7 +3220,7 @@ def aten〇ScalarImplicit〡dtype(a_rank_dtype: Tuple[int, int]) -> int:
assert False, "Unexpected dtype!"
@check_dtype_function([Invocation(0), Invocation(0.0)])
def prim〇NumToTensor〇Scalar〡dtype(a: Union[int, float]) -> int:
def prim〇NumToTensor〇Scalar〡dtype(a: Union[int, float, complex]) -> int:
return get_dtype_of_scalar(a)
@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, dim=0) +
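Taken together, the widened signatures mean a complex scalar can now flow through these dtype functions without the reified call tripping over the `!torch.number` operand type. A minimal usage sketch for two of the functions updated above (the import path is illustrative; rank/dtype pairs are passed as `(rank, dtype)` tuples):
import torch
# Hypothetical import path for illustration only.
from abstract_interp_lib_gen import aten〇fill〇Scalar〡dtype, aten〇eq〇Scalar〡dtype
rank_dtype = (2, torch.float32)  # a rank-2 float32 tensor, given as (rank, dtype)
# fill.Scalar propagates the tensor dtype regardless of the scalar's type,
# so a complex fill value is accepted without a type mismatch.
assert aten〇fill〇Scalar〡dtype(rank_dtype, value=1j) == torch.float32
# eq.Scalar always produces a boolean tensor, whatever the scalar is.
assert aten〇eq〇Scalar〡dtype(rank_dtype, other=1j) == torch.bool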

View File

@ -63,7 +63,7 @@ def get_priority_of_dtype(dtype: int) -> int:
return 11
assert False, "Cannot determine priority of dtype"
def get_dtype_of_scalar(scalar: Union[int, float]) -> int:
def get_dtype_of_scalar(scalar: Union[int, float, complex]) -> int:
# This is hacky. `NumToTensor` is the only PyTorch op for scalars
# that when `jit.script`ed converts a float scalar to a tensor
# with dtype that corresponds to Python's `float`.
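The body of this helper is elided by the hunk, but the comment describes the idea: let `prim.NumToTensor` decide which dtype a given Python scalar corresponds to, with the annotation change simply admitting complex scalars as well. A minimal sketch of that idea (the name is chosen here for illustration):
from typing import Union
import torch
def get_dtype_of_scalar_sketch(scalar: Union[int, float, complex]) -> int:
    # Delegate to prim.NumToTensor so the scalar-to-dtype mapping matches what
    # TorchScript uses when it materializes a scalar as a 0-d tensor.
    return torch.ops.prim.NumToTensor(scalar).dtype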

View File

@ -49,7 +49,7 @@ def _get_default_value(arg: "SIG_ATTR_TYPE") -> str:
def _pytype_to_fn_pytype_common(pytype: str) -> str:
if "number" in pytype:
return pytype.replace("number", "Union[int, float]")
return pytype.replace("number", "Union[int, float, complex]")
# `torch.device` is lowercase.
if pytype == "Device":
return "device"
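The effect on the generated function signatures follows directly from the substitution above; illustrative expectations, assuming `_pytype_to_fn_pytype_common` is in scope:
# Any JIT schema type containing "number" is widened to the full NumberType,
# and "Device" is still lowercased to the Python-level `device`.
assert _pytype_to_fn_pytype_common("number") == "Union[int, float, complex]"
assert _pytype_to_fn_pytype_common("Optional[number]") == "Optional[Union[int, float, complex]]"
assert _pytype_to_fn_pytype_common("Device") == "device"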

View File

@ -72,3 +72,18 @@ func.func @turn_tensors_into_rank_and_dtype_args(%arg0: !torch.vtensor, %arg1: !
%0 = torch.aten.floor_divide %arg0, %arg1 : !torch.vtensor, !torch.vtensor -> !torch.vtensor
return %0 : !torch.vtensor
}
// -----
// CHECK-LABEL: func.func private @__torch_mlir_dtype_fn.aten.arange(
// CHECK-LABEL: func.func @derefine_int_to_number() -> !torch.vtensor {
// CHECK: %[[INT1:.*]] = torch.constant.int 1
// CHECK: %[[NUMBER:.*]] = torch.derefine %[[INT1]] : !torch.int to !torch.number
// CHECK: {{.*}} = func.call @__torch_mlir_dtype_fn.aten.arange(%[[NUMBER]], {{.*}}) : (!torch.number, {{.*}}) -> !torch.int
func.func @derefine_int_to_number() -> !torch.vtensor {
%int1 = torch.constant.int 1
%none = torch.constant.none
%0 = torch.aten.arange %int1, %none, %none, %none, %none : !torch.int, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor
return %0 : !torch.vtensor
}