mirror of https://github.com/llvm/torch-mlir
Fix handling of `!torch.number` in abstract interpretation library (#2309)
In PyTorch, the `NumberType` is equal to `Union[int, float, complex]`. However, the abstract interpretation library was treating the `NumberType` as `Union[int, float]`, resulting in type mismatches when reifying certain dtype functions. This commit fixes the type inconsistency by having the abstract interpretation functions take as an input a `Union[int, float, complex]` for the ops that take `!torch.number` inputs.pull/2318/head
parent
5706697e0b
commit
718f53ff8a
|
@ -7287,7 +7287,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
|
|||
" %0 = torch.prim.ListConstruct : () -> !torch.list<int>\n"
|
||||
" return %0 : !torch.list<int>\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_dtype_fn.aten.scalar_tensor\"(%arg0: !torch.union<float, int>, %arg1: !torch.optional<int>, %arg2: !torch.optional<int>, %arg3: !torch.optional<Device>, %arg4: !torch.optional<bool>) -> !torch.int {\n"
|
||||
" func.func @\"__torch_mlir_dtype_fn.aten.scalar_tensor\"(%arg0: !torch.number, %arg1: !torch.optional<int>, %arg2: !torch.optional<int>, %arg3: !torch.optional<Device>, %arg4: !torch.optional<bool>) -> !torch.int {\n"
|
||||
" %int6 = torch.constant.int 6\n"
|
||||
" %none = torch.constant.none\n"
|
||||
" %0 = torch.aten.__isnot__ %arg1, %none : !torch.optional<int>, !torch.none -> !torch.bool\n"
|
||||
|
@ -8026,7 +8026,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
|
|||
" %1 = call @__torch__._get_dtype_of_floating_point_op(%0#1) : (!torch.int) -> !torch.int\n"
|
||||
" return %1 : !torch.int\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_dtype_fn.aten.softplus\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.union<float, int>, %arg2: !torch.union<float, int>) -> !torch.int {\n"
|
||||
" func.func @\"__torch_mlir_dtype_fn.aten.softplus\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.number, %arg2: !torch.number) -> !torch.int {\n"
|
||||
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
|
||||
" %1 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_integer_dtype(%0#1) : (!torch.int) -> !torch.bool\n"
|
||||
" %2 = torch.prim.If %1 -> (!torch.int) {\n"
|
||||
|
@ -8166,7 +8166,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
|
|||
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
|
||||
" return %0#1 : !torch.int\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_dtype_fn.aten.clamp_max\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.union<float, int>) -> !torch.int {\n"
|
||||
" func.func @\"__torch_mlir_dtype_fn.aten.clamp_max\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.number) -> !torch.int {\n"
|
||||
" %int4 = torch.constant.int 4\n"
|
||||
" %int11 = torch.constant.int 11\n"
|
||||
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
|
||||
|
@ -8178,7 +8178,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
|
|||
" }\n"
|
||||
" return %2 : !torch.int\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_dtype_fn.aten.clamp_min\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.union<float, int>) -> !torch.int {\n"
|
||||
" func.func @\"__torch_mlir_dtype_fn.aten.clamp_min\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.number) -> !torch.int {\n"
|
||||
" %int4 = torch.constant.int 4\n"
|
||||
" %int11 = torch.constant.int 11\n"
|
||||
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
|
||||
|
@ -8190,7 +8190,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
|
|||
" }\n"
|
||||
" return %2 : !torch.int\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_dtype_fn.aten.clamp\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.union<float, int, none>, %arg2: !torch.union<float, int, none>) -> !torch.int {\n"
|
||||
" func.func @\"__torch_mlir_dtype_fn.aten.clamp\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.optional<number>, %arg2: !torch.optional<number>) -> !torch.int {\n"
|
||||
" %int4 = torch.constant.int 4\n"
|
||||
" %int11 = torch.constant.int 11\n"
|
||||
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
|
||||
|
@ -8206,7 +8206,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
|
|||
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
|
||||
" return %0#1 : !torch.int\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_dtype_fn.aten.constant_pad_nd\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.list<int>, %arg2: !torch.union<float, int>) -> !torch.int {\n"
|
||||
" func.func @\"__torch_mlir_dtype_fn.aten.constant_pad_nd\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.list<int>, %arg2: !torch.number) -> !torch.int {\n"
|
||||
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
|
||||
" return %0#1 : !torch.int\n"
|
||||
" }\n"
|
||||
|
@ -8263,7 +8263,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
|
|||
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
|
||||
" return %0#1 : !torch.int\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_dtype_fn.aten.fill.Scalar\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.union<float, int>) -> !torch.int {\n"
|
||||
" func.func @\"__torch_mlir_dtype_fn.aten.fill.Scalar\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.number) -> !torch.int {\n"
|
||||
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
|
||||
" return %0#1 : !torch.int\n"
|
||||
" }\n"
|
||||
|
@ -8315,7 +8315,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
|
|||
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
|
||||
" return %0#1 : !torch.int\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_dtype_fn.aten.hardtanh_backward\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.union<float, int>, %arg3: !torch.union<float, int>) -> !torch.int {\n"
|
||||
" func.func @\"__torch_mlir_dtype_fn.aten.hardtanh_backward\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.number, %arg3: !torch.number) -> !torch.int {\n"
|
||||
" %int6 = torch.constant.int 6\n"
|
||||
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
|
||||
" %1 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_integer_dtype(%0#1) : (!torch.int) -> !torch.bool\n"
|
||||
|
@ -8326,7 +8326,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
|
|||
" }\n"
|
||||
" return %2 : !torch.int\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_dtype_fn.aten.hardtanh\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.union<float, int>, %arg2: !torch.union<float, int>) -> !torch.int {\n"
|
||||
" func.func @\"__torch_mlir_dtype_fn.aten.hardtanh\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.number, %arg2: !torch.number) -> !torch.int {\n"
|
||||
" %none = torch.constant.none\n"
|
||||
" %str = torch.constant.str \"AssertionError: \"\n"
|
||||
" %int11 = torch.constant.int 11\n"
|
||||
|
@ -8381,7 +8381,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
|
|||
" }\n"
|
||||
" return %0#1 : !torch.int\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_dtype_fn.aten.leaky_relu_backward\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.union<float, int>, %arg3: !torch.bool) -> !torch.int {\n"
|
||||
" func.func @\"__torch_mlir_dtype_fn.aten.leaky_relu_backward\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.number, %arg3: !torch.bool) -> !torch.int {\n"
|
||||
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
|
||||
" %1:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
|
||||
" %2 = torch.prim.ListConstruct %0#0, %1#0 : (!torch.int, !torch.int) -> !torch.list<optional<int>>\n"
|
||||
|
@ -8396,11 +8396,11 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
|
|||
" func.func @\"__torch_mlir_dtype_fn.aten._log_softmax_backward_data\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.int, %arg3: !torch.int) -> !torch.int {\n"
|
||||
" return %arg3 : !torch.int\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_dtype_fn.aten.masked_fill.Scalar\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.union<float, int>) -> !torch.int {\n"
|
||||
" func.func @\"__torch_mlir_dtype_fn.aten.masked_fill.Scalar\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.number) -> !torch.int {\n"
|
||||
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
|
||||
" return %0#1 : !torch.int\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_dtype_fn.aten.masked_fill_.Scalar\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.union<float, int>) -> !torch.int {\n"
|
||||
" func.func @\"__torch_mlir_dtype_fn.aten.masked_fill_.Scalar\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.number) -> !torch.int {\n"
|
||||
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
|
||||
" return %0#1 : !torch.int\n"
|
||||
" }\n"
|
||||
|
@ -8544,7 +8544,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
|
|||
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
|
||||
" return %0#1 : !torch.int\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_dtype_fn.aten.scatter.value\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.int, %arg2: !torch.tuple<int, int>, %arg3: !torch.union<float, int>) -> !torch.int {\n"
|
||||
" func.func @\"__torch_mlir_dtype_fn.aten.scatter.value\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.int, %arg2: !torch.tuple<int, int>, %arg3: !torch.number) -> !torch.int {\n"
|
||||
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
|
||||
" return %0#1 : !torch.int\n"
|
||||
" }\n"
|
||||
|
@ -8591,7 +8591,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
|
|||
" %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n"
|
||||
" return %4 : !torch.int\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_dtype_fn.aten.threshold\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.union<float, int>, %arg2: !torch.union<float, int>) -> !torch.int {\n"
|
||||
" func.func @\"__torch_mlir_dtype_fn.aten.threshold\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.number, %arg2: !torch.number) -> !torch.int {\n"
|
||||
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
|
||||
" return %0#1 : !torch.int\n"
|
||||
" }\n"
|
||||
|
@ -8647,12 +8647,12 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
|
|||
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
|
||||
" return %0#1 : !torch.int\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_dtype_fn.prim.abs.Scalar\"(%arg0: !torch.union<float, int>) -> !torch.int {\n"
|
||||
" %0 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg0) : (!torch.union<float, int>) -> !torch.int\n"
|
||||
" func.func @\"__torch_mlir_dtype_fn.prim.abs.Scalar\"(%arg0: !torch.number) -> !torch.int {\n"
|
||||
" %0 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg0) : (!torch.number) -> !torch.int\n"
|
||||
" return %0 : !torch.int\n"
|
||||
" }\n"
|
||||
" func.func @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg0: !torch.union<float, int>) -> !torch.int {\n"
|
||||
" %0 = torch.prim.NumToTensor.Scalar %arg0 : !torch.union<float, int> -> !torch.tensor\n"
|
||||
" func.func @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg0: !torch.number) -> !torch.int {\n"
|
||||
" %0 = torch.prim.NumToTensor.Scalar %arg0 : !torch.number -> !torch.tensor\n"
|
||||
" %1 = torch.prim.dtype %0 : !torch.tensor -> !torch.int\n"
|
||||
" return %1 : !torch.int\n"
|
||||
" }\n"
|
||||
|
@ -8710,7 +8710,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
|
|||
" }\n"
|
||||
" return %2 : !torch.int\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_dtype_fn.aten.eq.Scalar\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.union<float, int>) -> !torch.int {\n"
|
||||
" func.func @\"__torch_mlir_dtype_fn.aten.eq.Scalar\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.number) -> !torch.int {\n"
|
||||
" %int11 = torch.constant.int 11\n"
|
||||
" return %int11 : !torch.int\n"
|
||||
" }\n"
|
||||
|
@ -8718,11 +8718,11 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
|
|||
" %int11 = torch.constant.int 11\n"
|
||||
" return %int11 : !torch.int\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_dtype_fn.aten.ge.Scalar\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.union<float, int>) -> !torch.int {\n"
|
||||
" func.func @\"__torch_mlir_dtype_fn.aten.ge.Scalar\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.number) -> !torch.int {\n"
|
||||
" %int11 = torch.constant.int 11\n"
|
||||
" return %int11 : !torch.int\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_dtype_fn.aten.gt.Scalar\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.union<float, int>) -> !torch.int {\n"
|
||||
" func.func @\"__torch_mlir_dtype_fn.aten.gt.Scalar\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.number) -> !torch.int {\n"
|
||||
" %int11 = torch.constant.int 11\n"
|
||||
" return %int11 : !torch.int\n"
|
||||
" }\n"
|
||||
|
@ -8734,7 +8734,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
|
|||
" %int11 = torch.constant.int 11\n"
|
||||
" return %int11 : !torch.int\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_dtype_fn.aten.le.Scalar\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.union<float, int>) -> !torch.int {\n"
|
||||
" func.func @\"__torch_mlir_dtype_fn.aten.le.Scalar\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.number) -> !torch.int {\n"
|
||||
" %int11 = torch.constant.int 11\n"
|
||||
" return %int11 : !torch.int\n"
|
||||
" }\n"
|
||||
|
@ -8758,7 +8758,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
|
|||
" %int11 = torch.constant.int 11\n"
|
||||
" return %int11 : !torch.int\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_dtype_fn.aten.lt.Scalar\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.union<float, int>) -> !torch.int {\n"
|
||||
" func.func @\"__torch_mlir_dtype_fn.aten.lt.Scalar\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.number) -> !torch.int {\n"
|
||||
" %int11 = torch.constant.int 11\n"
|
||||
" return %int11 : !torch.int\n"
|
||||
" }\n"
|
||||
|
@ -8778,15 +8778,15 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
|
|||
" %int11 = torch.constant.int 11\n"
|
||||
" return %int11 : !torch.int\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_dtype_fn.aten.ne.Scalar\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.union<float, int>) -> !torch.int {\n"
|
||||
" func.func @\"__torch_mlir_dtype_fn.aten.ne.Scalar\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.number) -> !torch.int {\n"
|
||||
" %int11 = torch.constant.int 11\n"
|
||||
" return %int11 : !torch.int\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_dtype_fn.aten.add\"(%arg0: !torch.union<float, int>, %arg1: !torch.union<float, int>) -> !torch.int {\n"
|
||||
" func.func @\"__torch_mlir_dtype_fn.aten.add\"(%arg0: !torch.number, %arg1: !torch.number) -> !torch.int {\n"
|
||||
" %none = torch.constant.none\n"
|
||||
" %0 = torch.prim.ListConstruct %none, %none : (!torch.none, !torch.none) -> !torch.list<optional<int>>\n"
|
||||
" %1 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg0) : (!torch.union<float, int>) -> !torch.int\n"
|
||||
" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg1) : (!torch.union<float, int>) -> !torch.int\n"
|
||||
" %1 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg0) : (!torch.number) -> !torch.int\n"
|
||||
" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg1) : (!torch.number) -> !torch.int\n"
|
||||
" %3 = torch.prim.ListConstruct %1, %2 : (!torch.int, !torch.int) -> !torch.list<int>\n"
|
||||
" %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%0, %3) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n"
|
||||
" return %4 : !torch.int\n"
|
||||
|
@ -8835,11 +8835,11 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
|
|||
" }\n"
|
||||
" return %3 : !torch.int\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_dtype_fn.aten.rsub.Scalar\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.union<float, int>, %arg2: !torch.union<float, int>) -> !torch.int {\n"
|
||||
" func.func @\"__torch_mlir_dtype_fn.aten.rsub.Scalar\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.number, %arg2: !torch.number) -> !torch.int {\n"
|
||||
" %none = torch.constant.none\n"
|
||||
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
|
||||
" %1 = torch.prim.ListConstruct %0#0, %none : (!torch.int, !torch.none) -> !torch.list<optional<int>>\n"
|
||||
" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg1) : (!torch.union<float, int>) -> !torch.int\n"
|
||||
" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg1) : (!torch.number) -> !torch.int\n"
|
||||
" %3 = torch.prim.ListConstruct %0#1, %2 : (!torch.int, !torch.int) -> !torch.list<int>\n"
|
||||
" %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%1, %3) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n"
|
||||
" return %4 : !torch.int\n"
|
||||
|
@ -8852,7 +8852,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
|
|||
" %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n"
|
||||
" return %4 : !torch.int\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_dtype_fn.aten.add.Tensor\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.union<float, int>) -> !torch.int {\n"
|
||||
" func.func @\"__torch_mlir_dtype_fn.aten.add.Tensor\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.number) -> !torch.int {\n"
|
||||
" %0:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
|
||||
" %1:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
|
||||
" %2 = torch.prim.ListConstruct %1#0, %0#0 : (!torch.int, !torch.int) -> !torch.list<optional<int>>\n"
|
||||
|
@ -9176,7 +9176,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
|
|||
" %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n"
|
||||
" return %4 : !torch.int\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_dtype_fn.aten.sub.Tensor\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.union<float, int>) -> !torch.int {\n"
|
||||
" func.func @\"__torch_mlir_dtype_fn.aten.sub.Tensor\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.number) -> !torch.int {\n"
|
||||
" %0:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
|
||||
" %1:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
|
||||
" %2 = torch.prim.ListConstruct %1#0, %0#0 : (!torch.int, !torch.int) -> !torch.list<optional<int>>\n"
|
||||
|
@ -9184,7 +9184,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
|
|||
" %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n"
|
||||
" return %4 : !torch.int\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_dtype_fn.aten.threshold_backward\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.union<float, int>) -> !torch.int {\n"
|
||||
" func.func @\"__torch_mlir_dtype_fn.aten.threshold_backward\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.number) -> !torch.int {\n"
|
||||
" %str = torch.constant.str \"AssertionError: Result dtype for aten.threshold_backward cannot be bool or float16\"\n"
|
||||
" %int11 = torch.constant.int 11\n"
|
||||
" %str_0 = torch.constant.str \"AssertionError: `self` cannot be complex\"\n"
|
||||
|
@ -9358,7 +9358,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
|
|||
" }\n"
|
||||
" return %4 : !torch.int\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_dtype_fn.aten.addmm\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.tuple<int, int>, %arg3: !torch.union<float, int>, %arg4: !torch.union<float, int>) -> !torch.int {\n"
|
||||
" func.func @\"__torch_mlir_dtype_fn.aten.addmm\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.tuple<int, int>, %arg3: !torch.number, %arg4: !torch.number) -> !torch.int {\n"
|
||||
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
|
||||
" %1:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
|
||||
" %2:2 = torch.prim.TupleUnpack %arg2 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
|
||||
|
@ -9376,7 +9376,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
|
|||
" %5 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%3, %4) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n"
|
||||
" return %5 : !torch.int\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_dtype_fn.aten.addcmul\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.tuple<int, int>, %arg3: !torch.union<float, int>) -> !torch.int {\n"
|
||||
" func.func @\"__torch_mlir_dtype_fn.aten.addcmul\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.tuple<int, int>, %arg3: !torch.number) -> !torch.int {\n"
|
||||
" %none = torch.constant.none\n"
|
||||
" %str = torch.constant.str \"AssertionError: \"\n"
|
||||
" %int11 = torch.constant.int 11\n"
|
||||
|
@ -9409,7 +9409,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
|
|||
" %8 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%6, %7) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n"
|
||||
" return %8 : !torch.int\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_dtype_fn.aten.addcdiv\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.tuple<int, int>, %arg3: !torch.union<float, int>) -> !torch.int {\n"
|
||||
" func.func @\"__torch_mlir_dtype_fn.aten.addcdiv\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.tuple<int, int>, %arg3: !torch.number) -> !torch.int {\n"
|
||||
" %int6 = torch.constant.int 6\n"
|
||||
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
|
||||
" %1:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
|
||||
|
@ -9425,39 +9425,39 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
|
|||
" }\n"
|
||||
" return %7 : !torch.int\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_dtype_fn.aten.add.Scalar\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.union<float, int>, %arg2: !torch.union<float, int>) -> !torch.int {\n"
|
||||
" func.func @\"__torch_mlir_dtype_fn.aten.add.Scalar\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.number, %arg2: !torch.number) -> !torch.int {\n"
|
||||
" %none = torch.constant.none\n"
|
||||
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
|
||||
" %1 = torch.prim.ListConstruct %0#0, %none : (!torch.int, !torch.none) -> !torch.list<optional<int>>\n"
|
||||
" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg1) : (!torch.union<float, int>) -> !torch.int\n"
|
||||
" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg1) : (!torch.number) -> !torch.int\n"
|
||||
" %3 = torch.prim.ListConstruct %0#1, %2 : (!torch.int, !torch.int) -> !torch.list<int>\n"
|
||||
" %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%1, %3) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n"
|
||||
" return %4 : !torch.int\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_dtype_fn.aten.sub.Scalar\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.union<float, int>, %arg2: !torch.union<float, int>) -> !torch.int {\n"
|
||||
" func.func @\"__torch_mlir_dtype_fn.aten.sub.Scalar\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.number, %arg2: !torch.number) -> !torch.int {\n"
|
||||
" %none = torch.constant.none\n"
|
||||
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
|
||||
" %1 = torch.prim.ListConstruct %0#0, %none : (!torch.int, !torch.none) -> !torch.list<optional<int>>\n"
|
||||
" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg1) : (!torch.union<float, int>) -> !torch.int\n"
|
||||
" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg1) : (!torch.number) -> !torch.int\n"
|
||||
" %3 = torch.prim.ListConstruct %0#1, %2 : (!torch.int, !torch.int) -> !torch.list<int>\n"
|
||||
" %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%1, %3) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n"
|
||||
" return %4 : !torch.int\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_dtype_fn.aten.mul.Scalar\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.union<float, int>) -> !torch.int {\n"
|
||||
" func.func @\"__torch_mlir_dtype_fn.aten.mul.Scalar\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.number) -> !torch.int {\n"
|
||||
" %none = torch.constant.none\n"
|
||||
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
|
||||
" %1 = torch.prim.ListConstruct %0#0, %none : (!torch.int, !torch.none) -> !torch.list<optional<int>>\n"
|
||||
" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg1) : (!torch.union<float, int>) -> !torch.int\n"
|
||||
" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg1) : (!torch.number) -> !torch.int\n"
|
||||
" %3 = torch.prim.ListConstruct %0#1, %2 : (!torch.int, !torch.int) -> !torch.list<int>\n"
|
||||
" %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%1, %3) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n"
|
||||
" return %4 : !torch.int\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_dtype_fn.aten.div.Scalar\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.union<float, int>) -> !torch.int {\n"
|
||||
" func.func @\"__torch_mlir_dtype_fn.aten.div.Scalar\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.number) -> !torch.int {\n"
|
||||
" %int6 = torch.constant.int 6\n"
|
||||
" %none = torch.constant.none\n"
|
||||
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
|
||||
" %1 = torch.prim.ListConstruct %0#0, %none : (!torch.int, !torch.none) -> !torch.list<optional<int>>\n"
|
||||
" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg1) : (!torch.union<float, int>) -> !torch.int\n"
|
||||
" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg1) : (!torch.number) -> !torch.int\n"
|
||||
" %3 = torch.prim.ListConstruct %0#1, %2 : (!torch.int, !torch.int) -> !torch.list<int>\n"
|
||||
" %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%1, %3) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n"
|
||||
" %5 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_integer_dtype(%4) : (!torch.int) -> !torch.bool\n"
|
||||
|
@ -9468,16 +9468,16 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
|
|||
" }\n"
|
||||
" return %6 : !torch.int\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_dtype_fn.aten.fmod.Scalar\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.union<float, int>) -> !torch.int {\n"
|
||||
" func.func @\"__torch_mlir_dtype_fn.aten.fmod.Scalar\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.number) -> !torch.int {\n"
|
||||
" %none = torch.constant.none\n"
|
||||
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
|
||||
" %1 = torch.prim.ListConstruct %0#0, %none : (!torch.int, !torch.none) -> !torch.list<optional<int>>\n"
|
||||
" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg1) : (!torch.union<float, int>) -> !torch.int\n"
|
||||
" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg1) : (!torch.number) -> !torch.int\n"
|
||||
" %3 = torch.prim.ListConstruct %0#1, %2 : (!torch.int, !torch.int) -> !torch.list<int>\n"
|
||||
" %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%1, %3) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n"
|
||||
" return %4 : !torch.int\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_dtype_fn.aten.floor_divide.Scalar\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.union<float, int>) -> !torch.int {\n"
|
||||
" func.func @\"__torch_mlir_dtype_fn.aten.floor_divide.Scalar\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.number) -> !torch.int {\n"
|
||||
" %none = torch.constant.none\n"
|
||||
" %str = torch.constant.str \"AssertionError: \"\n"
|
||||
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
|
||||
|
@ -9490,21 +9490,21 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
|
|||
" torch.prim.If.yield\n"
|
||||
" }\n"
|
||||
" %3 = torch.prim.ListConstruct %0#0, %none : (!torch.int, !torch.none) -> !torch.list<optional<int>>\n"
|
||||
" %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg1) : (!torch.union<float, int>) -> !torch.int\n"
|
||||
" %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg1) : (!torch.number) -> !torch.int\n"
|
||||
" %5 = torch.prim.ListConstruct %0#1, %4 : (!torch.int, !torch.int) -> !torch.list<int>\n"
|
||||
" %6 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%3, %5) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n"
|
||||
" return %6 : !torch.int\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_dtype_fn.aten.pow.Tensor_Scalar\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.union<float, int>) -> !torch.int {\n"
|
||||
" func.func @\"__torch_mlir_dtype_fn.aten.pow.Tensor_Scalar\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.number) -> !torch.int {\n"
|
||||
" %none = torch.constant.none\n"
|
||||
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
|
||||
" %1 = torch.prim.ListConstruct %0#0, %none : (!torch.int, !torch.none) -> !torch.list<optional<int>>\n"
|
||||
" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg1) : (!torch.union<float, int>) -> !torch.int\n"
|
||||
" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg1) : (!torch.number) -> !torch.int\n"
|
||||
" %3 = torch.prim.ListConstruct %0#1, %2 : (!torch.int, !torch.int) -> !torch.list<int>\n"
|
||||
" %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%1, %3) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n"
|
||||
" return %4 : !torch.int\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_dtype_fn.aten.leaky_relu\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.union<float, int>) -> !torch.int {\n"
|
||||
" func.func @\"__torch_mlir_dtype_fn.aten.leaky_relu\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.number) -> !torch.int {\n"
|
||||
" %none = torch.constant.none\n"
|
||||
" %str = torch.constant.str \"AssertionError: \"\n"
|
||||
" %int11 = torch.constant.int 11\n"
|
||||
|
@ -9517,7 +9517,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
|
|||
" torch.prim.If.yield\n"
|
||||
" }\n"
|
||||
" %2 = torch.prim.ListConstruct %0#0, %none : (!torch.int, !torch.none) -> !torch.list<optional<int>>\n"
|
||||
" %3 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg1) : (!torch.union<float, int>) -> !torch.int\n"
|
||||
" %3 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg1) : (!torch.number) -> !torch.int\n"
|
||||
" %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_float_dtype(%3) : (!torch.int) -> !torch.bool\n"
|
||||
" torch.prim.If %4 -> () {\n"
|
||||
" %7 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_integer_dtype(%0#1) : (!torch.int) -> !torch.bool\n"
|
||||
|
@ -9536,16 +9536,16 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
|
|||
" %6 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%2, %5) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n"
|
||||
" return %6 : !torch.int\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_dtype_fn.aten.remainder.Scalar\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.union<float, int>) -> !torch.int {\n"
|
||||
" func.func @\"__torch_mlir_dtype_fn.aten.remainder.Scalar\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.number) -> !torch.int {\n"
|
||||
" %none = torch.constant.none\n"
|
||||
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
|
||||
" %1 = torch.prim.ListConstruct %0#0, %none : (!torch.int, !torch.none) -> !torch.list<optional<int>>\n"
|
||||
" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg1) : (!torch.union<float, int>) -> !torch.int\n"
|
||||
" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg1) : (!torch.number) -> !torch.int\n"
|
||||
" %3 = torch.prim.ListConstruct %0#1, %2 : (!torch.int, !torch.int) -> !torch.list<int>\n"
|
||||
" %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%1, %3) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n"
|
||||
" return %4 : !torch.int\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_dtype_fn.aten.baddbmm\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.tuple<int, int>, %arg3: !torch.union<float, int>, %arg4: !torch.union<float, int>) -> !torch.int {\n"
|
||||
" func.func @\"__torch_mlir_dtype_fn.aten.baddbmm\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.tuple<int, int>, %arg3: !torch.number, %arg4: !torch.number) -> !torch.int {\n"
|
||||
" %none = torch.constant.none\n"
|
||||
" %str = torch.constant.str \"AssertionError: \"\n"
|
||||
" %int5 = torch.constant.int 5\n"
|
||||
|
@ -9590,14 +9590,14 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
|
|||
" %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n"
|
||||
" return %4 : !torch.int\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_dtype_fn.aten.where.Scalar\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.union<float, int>, %arg2: !torch.union<float, int>) -> !torch.int {\n"
|
||||
" func.func @\"__torch_mlir_dtype_fn.aten.where.Scalar\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.number, %arg2: !torch.number) -> !torch.int {\n"
|
||||
" %int6 = torch.constant.int 6\n"
|
||||
" %int4 = torch.constant.int 4\n"
|
||||
" %false = torch.constant.bool false\n"
|
||||
" %0 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg1) : (!torch.union<float, int>) -> !torch.int\n"
|
||||
" %0 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg1) : (!torch.number) -> !torch.int\n"
|
||||
" %1 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_integer_dtype(%0) : (!torch.int) -> !torch.bool\n"
|
||||
" %2 = torch.prim.If %1 -> (!torch.bool) {\n"
|
||||
" %4 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg2) : (!torch.union<float, int>) -> !torch.int\n"
|
||||
" %4 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg2) : (!torch.number) -> !torch.int\n"
|
||||
" %5 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_integer_dtype(%4) : (!torch.int) -> !torch.bool\n"
|
||||
" torch.prim.If.yield %5 : !torch.bool\n"
|
||||
" } else {\n"
|
||||
|
@ -9610,20 +9610,20 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
|
|||
" }\n"
|
||||
" return %3 : !torch.int\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_dtype_fn.aten.where.ScalarOther\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.union<float, int>) -> !torch.int {\n"
|
||||
" func.func @\"__torch_mlir_dtype_fn.aten.where.ScalarOther\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.number) -> !torch.int {\n"
|
||||
" %none = torch.constant.none\n"
|
||||
" %0:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
|
||||
" %1 = torch.prim.ListConstruct %0#0, %none : (!torch.int, !torch.none) -> !torch.list<optional<int>>\n"
|
||||
" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg2) : (!torch.union<float, int>) -> !torch.int\n"
|
||||
" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg2) : (!torch.number) -> !torch.int\n"
|
||||
" %3 = torch.prim.ListConstruct %0#1, %2 : (!torch.int, !torch.int) -> !torch.list<int>\n"
|
||||
" %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%1, %3) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n"
|
||||
" return %4 : !torch.int\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_dtype_fn.aten.where.ScalarSelf\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.union<float, int>, %arg2: !torch.tuple<int, int>) -> !torch.int {\n"
|
||||
" func.func @\"__torch_mlir_dtype_fn.aten.where.ScalarSelf\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.number, %arg2: !torch.tuple<int, int>) -> !torch.int {\n"
|
||||
" %none = torch.constant.none\n"
|
||||
" %0:2 = torch.prim.TupleUnpack %arg2 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
|
||||
" %1 = torch.prim.ListConstruct %none, %0#0 : (!torch.none, !torch.int) -> !torch.list<optional<int>>\n"
|
||||
" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg1) : (!torch.union<float, int>) -> !torch.int\n"
|
||||
" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg1) : (!torch.number) -> !torch.int\n"
|
||||
" %3 = torch.prim.ListConstruct %2, %0#1 : (!torch.int, !torch.int) -> !torch.list<int>\n"
|
||||
" %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%1, %3) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n"
|
||||
" return %4 : !torch.int\n"
|
||||
|
@ -9701,7 +9701,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
|
|||
" %3 = torch.prim.TupleConstruct %0#1, %0#1, %2 : !torch.int, !torch.int, !torch.int -> !torch.tuple<int, int, int>\n"
|
||||
" return %3 : !torch.tuple<int, int, int>\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_dtype_fn.aten.arange\"(%arg0: !torch.union<float, int>, %arg1: !torch.optional<int>, %arg2: !torch.optional<int>, %arg3: !torch.optional<Device>, %arg4: !torch.optional<bool>) -> !torch.int {\n"
|
||||
" func.func @\"__torch_mlir_dtype_fn.aten.arange\"(%arg0: !torch.number, %arg1: !torch.optional<int>, %arg2: !torch.optional<int>, %arg3: !torch.optional<Device>, %arg4: !torch.optional<bool>) -> !torch.int {\n"
|
||||
" %int4 = torch.constant.int 4\n"
|
||||
" %int6 = torch.constant.int 6\n"
|
||||
" %str = torch.constant.str \"AssertionError: \"\n"
|
||||
|
@ -9719,7 +9719,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
|
|||
" }\n"
|
||||
" torch.prim.If.yield %2 : !torch.int\n"
|
||||
" } else {\n"
|
||||
" %2 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg0) : (!torch.union<float, int>) -> !torch.int\n"
|
||||
" %2 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg0) : (!torch.number) -> !torch.int\n"
|
||||
" %3 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_float_dtype(%2) : (!torch.int) -> !torch.bool\n"
|
||||
" %4 = torch.prim.If %3 -> (!torch.int) {\n"
|
||||
" torch.prim.If.yield %int6 : !torch.int\n"
|
||||
|
@ -9730,7 +9730,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
|
|||
" }\n"
|
||||
" return %1 : !torch.int\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_dtype_fn.aten.arange.start\"(%arg0: !torch.union<float, int>, %arg1: !torch.union<float, int>, %arg2: !torch.optional<int>, %arg3: !torch.optional<int>, %arg4: !torch.optional<Device>, %arg5: !torch.optional<bool>) -> !torch.int {\n"
|
||||
" func.func @\"__torch_mlir_dtype_fn.aten.arange.start\"(%arg0: !torch.number, %arg1: !torch.number, %arg2: !torch.optional<int>, %arg3: !torch.optional<int>, %arg4: !torch.optional<Device>, %arg5: !torch.optional<bool>) -> !torch.int {\n"
|
||||
" %int4 = torch.constant.int 4\n"
|
||||
" %int6 = torch.constant.int 6\n"
|
||||
" %true = torch.constant.bool true\n"
|
||||
|
@ -9749,12 +9749,12 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
|
|||
" }\n"
|
||||
" torch.prim.If.yield %2 : !torch.int\n"
|
||||
" } else {\n"
|
||||
" %2 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg0) : (!torch.union<float, int>) -> !torch.int\n"
|
||||
" %2 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg0) : (!torch.number) -> !torch.int\n"
|
||||
" %3 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_float_dtype(%2) : (!torch.int) -> !torch.bool\n"
|
||||
" %4 = torch.prim.If %3 -> (!torch.bool) {\n"
|
||||
" torch.prim.If.yield %true : !torch.bool\n"
|
||||
" } else {\n"
|
||||
" %6 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg1) : (!torch.union<float, int>) -> !torch.int\n"
|
||||
" %6 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg1) : (!torch.number) -> !torch.int\n"
|
||||
" %7 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_float_dtype(%6) : (!torch.int) -> !torch.bool\n"
|
||||
" torch.prim.If.yield %7 : !torch.bool\n"
|
||||
" }\n"
|
||||
|
@ -9767,7 +9767,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
|
|||
" }\n"
|
||||
" return %1 : !torch.int\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_dtype_fn.aten.arange.start_step\"(%arg0: !torch.union<float, int>, %arg1: !torch.union<float, int>, %arg2: !torch.union<float, int>, %arg3: !torch.optional<int>, %arg4: !torch.optional<int>, %arg5: !torch.optional<Device>, %arg6: !torch.optional<bool>) -> !torch.int {\n"
|
||||
" func.func @\"__torch_mlir_dtype_fn.aten.arange.start_step\"(%arg0: !torch.number, %arg1: !torch.number, %arg2: !torch.number, %arg3: !torch.optional<int>, %arg4: !torch.optional<int>, %arg5: !torch.optional<Device>, %arg6: !torch.optional<bool>) -> !torch.int {\n"
|
||||
" %int4 = torch.constant.int 4\n"
|
||||
" %int6 = torch.constant.int 6\n"
|
||||
" %true = torch.constant.bool true\n"
|
||||
|
@ -9786,19 +9786,19 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
|
|||
" }\n"
|
||||
" torch.prim.If.yield %2 : !torch.int\n"
|
||||
" } else {\n"
|
||||
" %2 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg0) : (!torch.union<float, int>) -> !torch.int\n"
|
||||
" %2 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg0) : (!torch.number) -> !torch.int\n"
|
||||
" %3 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_float_dtype(%2) : (!torch.int) -> !torch.bool\n"
|
||||
" %4 = torch.prim.If %3 -> (!torch.bool) {\n"
|
||||
" torch.prim.If.yield %true : !torch.bool\n"
|
||||
" } else {\n"
|
||||
" %7 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg1) : (!torch.union<float, int>) -> !torch.int\n"
|
||||
" %7 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg1) : (!torch.number) -> !torch.int\n"
|
||||
" %8 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_float_dtype(%7) : (!torch.int) -> !torch.bool\n"
|
||||
" torch.prim.If.yield %8 : !torch.bool\n"
|
||||
" }\n"
|
||||
" %5 = torch.prim.If %4 -> (!torch.bool) {\n"
|
||||
" torch.prim.If.yield %true : !torch.bool\n"
|
||||
" } else {\n"
|
||||
" %7 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg2) : (!torch.union<float, int>) -> !torch.int\n"
|
||||
" %7 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg2) : (!torch.number) -> !torch.int\n"
|
||||
" %8 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_float_dtype(%7) : (!torch.int) -> !torch.bool\n"
|
||||
" torch.prim.If.yield %8 : !torch.bool\n"
|
||||
" }\n"
|
||||
|
@ -9910,7 +9910,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
|
|||
" %0 = call @\"__torch_mlir_dtype_fn.aten.std\"(%arg0, %true) : (!torch.tuple<int, int>, !torch.bool) -> !torch.int\n"
|
||||
" return %0 : !torch.int\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_dtype_fn.aten.std.correction\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.optional<list<int>>, %arg2: !torch.union<float, int, none>, %arg3: !torch.bool) -> !torch.int {\n"
|
||||
" func.func @\"__torch_mlir_dtype_fn.aten.std.correction\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.optional<list<int>>, %arg2: !torch.optional<number>, %arg3: !torch.bool) -> !torch.int {\n"
|
||||
" %true = torch.constant.bool true\n"
|
||||
" %0 = call @\"__torch_mlir_dtype_fn.aten.std\"(%arg0, %true) : (!torch.tuple<int, int>, !torch.bool) -> !torch.int\n"
|
||||
" return %0 : !torch.int\n"
|
||||
|
@ -9925,7 +9925,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
|
|||
" %0 = call @\"__torch_mlir_dtype_fn.aten.std\"(%arg0, %true) : (!torch.tuple<int, int>, !torch.bool) -> !torch.int\n"
|
||||
" return %0 : !torch.int\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_dtype_fn.aten.var.correction\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.optional<list<int>>, %arg2: !torch.union<float, int, none>, %arg3: !torch.bool) -> !torch.int {\n"
|
||||
" func.func @\"__torch_mlir_dtype_fn.aten.var.correction\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.optional<list<int>>, %arg2: !torch.optional<number>, %arg3: !torch.bool) -> !torch.int {\n"
|
||||
" %true = torch.constant.bool true\n"
|
||||
" %0 = call @\"__torch_mlir_dtype_fn.aten.std\"(%arg0, %true) : (!torch.tuple<int, int>, !torch.bool) -> !torch.int\n"
|
||||
" return %0 : !torch.int\n"
|
||||
|
@ -9935,7 +9935,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
|
|||
" %0 = call @\"__torch_mlir_dtype_fn.aten.std\"(%arg0, %true) : (!torch.tuple<int, int>, !torch.bool) -> !torch.int\n"
|
||||
" return %0 : !torch.int\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_dtype_fn.aten.linalg_vector_norm\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.union<float, int>, %arg2: !torch.optional<list<int>>, %arg3: !torch.bool, %arg4: !torch.optional<int>) -> !torch.int {\n"
|
||||
" func.func @\"__torch_mlir_dtype_fn.aten.linalg_vector_norm\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.number, %arg2: !torch.optional<list<int>>, %arg3: !torch.bool, %arg4: !torch.optional<int>) -> !torch.int {\n"
|
||||
" %true = torch.constant.bool true\n"
|
||||
" %none = torch.constant.none\n"
|
||||
" %str = torch.constant.str \"AssertionError: \"\n"
|
||||
|
@ -10061,7 +10061,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
|
|||
" }\n"
|
||||
" return %1 : !torch.int\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_dtype_fn.aten.full\"(%arg0: !torch.list<int>, %arg1: !torch.union<float, int>, %arg2: !torch.optional<int>, %arg3: !torch.optional<int>, %arg4: !torch.optional<Device>, %arg5: !torch.optional<bool>) -> !torch.int {\n"
|
||||
" func.func @\"__torch_mlir_dtype_fn.aten.full\"(%arg0: !torch.list<int>, %arg1: !torch.number, %arg2: !torch.optional<int>, %arg3: !torch.optional<int>, %arg4: !torch.optional<Device>, %arg5: !torch.optional<bool>) -> !torch.int {\n"
|
||||
" %int6 = torch.constant.int 6\n"
|
||||
" %none = torch.constant.none\n"
|
||||
" %0 = torch.aten.__isnot__ %arg2, %none : !torch.optional<int>, !torch.none -> !torch.bool\n"
|
||||
|
@ -10069,7 +10069,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
|
|||
" %2 = torch.prim.unchecked_cast %arg2 : !torch.optional<int> -> !torch.int\n"
|
||||
" torch.prim.If.yield %2 : !torch.int\n"
|
||||
" } else {\n"
|
||||
" %2 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg1) : (!torch.union<float, int>) -> !torch.int\n"
|
||||
" %2 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg1) : (!torch.number) -> !torch.int\n"
|
||||
" %3 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_float_dtype(%2) : (!torch.int) -> !torch.bool\n"
|
||||
" %4 = torch.prim.If %3 -> (!torch.int) {\n"
|
||||
" torch.prim.If.yield %int6 : !torch.int\n"
|
||||
|
@ -10116,7 +10116,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
|
|||
" }\n"
|
||||
" return %2 : !torch.int\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_dtype_fn.aten.full_like\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.union<float, int>, %arg2: !torch.optional<int>, %arg3: !torch.optional<int>, %arg4: !torch.optional<Device>, %arg5: !torch.optional<bool>, %arg6: !torch.optional<int>) -> !torch.int {\n"
|
||||
" func.func @\"__torch_mlir_dtype_fn.aten.full_like\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.number, %arg2: !torch.optional<int>, %arg3: !torch.optional<int>, %arg4: !torch.optional<Device>, %arg5: !torch.optional<bool>, %arg6: !torch.optional<int>) -> !torch.int {\n"
|
||||
" %none = torch.constant.none\n"
|
||||
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
|
||||
" %1 = torch.aten.__is__ %arg2, %none : !torch.optional<int>, !torch.none -> !torch.bool\n"
|
||||
|
@ -10313,7 +10313,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
|
|||
" }\n"
|
||||
" return %1 : !torch.int\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_dtype_fn.aten.var_mean.correction\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.optional<list<int>>, %arg2: !torch.union<float, int, none>, %arg3: !torch.bool) -> !torch.tuple<int, int> {\n"
|
||||
" func.func @\"__torch_mlir_dtype_fn.aten.var_mean.correction\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.optional<list<int>>, %arg2: !torch.optional<number>, %arg3: !torch.bool) -> !torch.tuple<int, int> {\n"
|
||||
" %int7 = torch.constant.int 7\n"
|
||||
" %int10 = torch.constant.int 10\n"
|
||||
" %int6 = torch.constant.int 6\n"
|
||||
|
@ -10485,8 +10485,8 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
|
|||
" }\n"
|
||||
" return %5 : !torch.int\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_dtype_fn.prim.NumToTensor.Scalar\"(%arg0: !torch.union<float, int>) -> !torch.int {\n"
|
||||
" %0 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg0) : (!torch.union<float, int>) -> !torch.int\n"
|
||||
" func.func @\"__torch_mlir_dtype_fn.prim.NumToTensor.Scalar\"(%arg0: !torch.number) -> !torch.int {\n"
|
||||
" %0 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg0) : (!torch.number) -> !torch.int\n"
|
||||
" return %0 : !torch.int\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_dtype_fn.aten.softmax.int\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.int, %arg2: !torch.optional<int>) -> !torch.int {\n"
|
||||
|
|
|
@ -176,10 +176,17 @@ FailureOr<Value> Torch::adjustFunctionArg(
|
|||
return b.create<DerefineOp>(loc, desiredType, operand).getResult();
|
||||
}
|
||||
|
||||
// !torch.union<int, float> or !torch.union<int, float, none> is the type used
|
||||
// for (optional) `Scalar` inputs. At compile time, such inputs will usually
|
||||
// be resolved to an `int` or a `float` so we need to derefine to match the
|
||||
// library function signature.
|
||||
// The type `!torch.number` can be an `int`, `float`, or `complex`.
|
||||
// TODO: Add a new type `Torch::ComplexType` to handle the complex case.
|
||||
if (desiredType.isa<Torch::NumberType>() &&
|
||||
operandType.isa<Torch::IntType, Torch::FloatType>()) {
|
||||
return b.create<DerefineOp>(loc, desiredType, operand).getResult();
|
||||
}
|
||||
|
||||
// !torch.union<int, float, none> is the type used for optional
|
||||
// `Scalar` inputs. At compile time, such inputs will usually be
|
||||
// resolved to an `int`, `float`, or `None` so we need to derefine
|
||||
// to match the library function signature.
|
||||
if (auto unionType = desiredType.dyn_cast<Torch::UnionType>()) {
|
||||
if (llvm::all_of(unionType.getContainedTypes(), [](Type containedType) {
|
||||
return containedType
|
||||
|
|
|
@ -772,7 +772,7 @@ def aten〇scalar_tensor〡shape(s: float, dtype: Optional[int] = None, layout:
|
|||
return []
|
||||
|
||||
@check_dtype_function([Invocation(-1), Invocation(-1.0)])
|
||||
def aten〇scalar_tensor〡dtype(s: Union[int, float], dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None) -> int:
|
||||
def aten〇scalar_tensor〡dtype(s: Union[int, float, complex], dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None) -> int:
|
||||
if dtype is not None:
|
||||
return dtype
|
||||
else:
|
||||
|
@ -1314,7 +1314,7 @@ def aten〇erf〡dtype(self_rank_dtype: Tuple[int, int]) -> int:
|
|||
return _get_dtype_of_floating_point_op(self_dtype)
|
||||
|
||||
@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1))
|
||||
def aten〇softplus〡dtype(self_rank_dtype: Tuple[int, int], beta: Union[int, float] = 1, threshold: Union[int, float] = 20) -> int:
|
||||
def aten〇softplus〡dtype(self_rank_dtype: Tuple[int, int], beta: Union[int, float, complex] = 1, threshold: Union[int, float, complex] = 20) -> int:
|
||||
self_rank, self_dtype = self_rank_dtype
|
||||
if is_integer_dtype(self_dtype):
|
||||
return self_dtype
|
||||
|
@ -1395,21 +1395,21 @@ def aten〇ceil〡dtype(self_rank_dtype: Tuple[int, int]) -> int:
|
|||
return self_dtype
|
||||
|
||||
@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, max=0))
|
||||
def aten〇clamp_max〡dtype(self_rank_dtype: Tuple[int, int], max: Union[int, float]) -> int:
|
||||
def aten〇clamp_max〡dtype(self_rank_dtype: Tuple[int, int], max: Union[int, float, complex]) -> int:
|
||||
self_rank, self_dtype = self_rank_dtype
|
||||
if self_dtype == torch.bool:
|
||||
return torch.int64
|
||||
return self_dtype
|
||||
|
||||
@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, min=0))
|
||||
def aten〇clamp_min〡dtype(self_rank_dtype: Tuple[int, int], min: Union[int, float]) -> int:
|
||||
def aten〇clamp_min〡dtype(self_rank_dtype: Tuple[int, int], min: Union[int, float, complex]) -> int:
|
||||
self_rank, self_dtype = self_rank_dtype
|
||||
if self_dtype == torch.bool:
|
||||
return torch.int64
|
||||
return self_dtype
|
||||
|
||||
@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, min=-1, max=1))
|
||||
def aten〇clamp〡dtype(self_rank_dtype: Tuple[int, int], min: Optional[Union[int, float]] = None, max: Optional[Union[int, float]] = None) -> int:
|
||||
def aten〇clamp〡dtype(self_rank_dtype: Tuple[int, int], min: Optional[Union[int, float, complex]] = None, max: Optional[Union[int, float, complex]] = None) -> int:
|
||||
self_rank, self_dtype = self_rank_dtype
|
||||
if self_dtype == torch.bool:
|
||||
return torch.int64
|
||||
|
@ -1421,7 +1421,7 @@ def aten〇clone〡dtype(self_rank_dtype: Tuple[int, int], memory_format: Option
|
|||
return self_dtype
|
||||
|
||||
@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, pad=[1, 1]))
|
||||
def aten〇constant_pad_nd〡dtype(self_rank_dtype: Tuple[int, int], pad: List[int], value: Union[int, float] = 0) -> int:
|
||||
def aten〇constant_pad_nd〡dtype(self_rank_dtype: Tuple[int, int], pad: List[int], value: Union[int, float, complex] = 0) -> int:
|
||||
self_rank, self_dtype = self_rank_dtype
|
||||
return self_dtype
|
||||
|
||||
|
@ -1478,7 +1478,7 @@ def aten〇expand〡dtype(self_rank_dtype: Tuple[int, int], size: List[int], imp
|
|||
return self_dtype
|
||||
|
||||
@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, value=0))
|
||||
def aten〇fill〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], value: Union[int, float]) -> int:
|
||||
def aten〇fill〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], value: Union[int, float, complex]) -> int:
|
||||
self_rank, self_dtype = self_rank_dtype
|
||||
return self_dtype
|
||||
|
||||
|
@ -1537,14 +1537,14 @@ def aten〇hardswish〡dtype(self_rank_dtype: Tuple[int, int]) -> int:
|
|||
return self_dtype
|
||||
|
||||
@check_dtype_function(_check_two_tensor_op(min_val=0.2, max_val=0.5))
|
||||
def aten〇hardtanh_backward〡dtype(grad_output_rank_dtype: Tuple[int, int], self_rank_dtype: Tuple[int, int], min_val: Union[int, float], max_val: Union[int, float]) -> int:
|
||||
def aten〇hardtanh_backward〡dtype(grad_output_rank_dtype: Tuple[int, int], self_rank_dtype: Tuple[int, int], min_val: Union[int, float, complex], max_val: Union[int, float, complex]) -> int:
|
||||
grad_output_rank, grad_output_dtype = grad_output_rank_dtype
|
||||
if is_integer_dtype(grad_output_dtype):
|
||||
return torch.float32
|
||||
return grad_output_dtype
|
||||
|
||||
@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, error_types={torch.uint8, torch.bool}))
|
||||
def aten〇hardtanh〡dtype(self_rank_dtype: Tuple[int, int], min_val: Union[int, float] = -1, max_val: Union[int, float] = 1) -> int:
|
||||
def aten〇hardtanh〡dtype(self_rank_dtype: Tuple[int, int], min_val: Union[int, float, complex] = -1, max_val: Union[int, float, complex] = 1) -> int:
|
||||
self_rank, self_dtype = self_rank_dtype
|
||||
assert self_dtype not in [torch.uint8, torch.bool]
|
||||
return self_dtype
|
||||
|
@ -1597,7 +1597,7 @@ def aten〇layer_norm〡dtype(input_rank_dtype: Tuple[int, int], normalized_shap
|
|||
return input_dtype
|
||||
|
||||
@check_dtype_function(_check_two_tensor_op(negative_slope=0.1, self_is_result=False))
|
||||
def aten〇leaky_relu_backward〡dtype(grad_output_rank_dtype: Tuple[int, int], self_rank_dtype: Tuple[int, int], negative_slope: Union[int, float], self_is_result: bool) -> int:
|
||||
def aten〇leaky_relu_backward〡dtype(grad_output_rank_dtype: Tuple[int, int], self_rank_dtype: Tuple[int, int], negative_slope: Union[int, float, complex], self_is_result: bool) -> int:
|
||||
grad_output_rank, grad_output_dtype = grad_output_rank_dtype
|
||||
self_rank, self_dtype = self_rank_dtype
|
||||
ranks: List[Optional[int]] = [grad_output_rank, self_rank]
|
||||
|
@ -1617,12 +1617,12 @@ def aten〇_log_softmax_backward_data〡dtype(grad_output_rank_dtype: Tuple[int,
|
|||
return input_dtype
|
||||
|
||||
@check_dtype_function(_check_tensors_with_the_same_dtype(None, [(3,)], None, None, TensorOfShape(1, dtype=torch.bool), 0))
|
||||
def aten〇masked_fill〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], mask_rank_dtype: Tuple[int, int], value: Union[int, float]) -> int:
|
||||
def aten〇masked_fill〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], mask_rank_dtype: Tuple[int, int], value: Union[int, float, complex]) -> int:
|
||||
self_rank, self_dtype = self_rank_dtype
|
||||
return self_dtype
|
||||
|
||||
@check_dtype_function(_check_tensors_with_the_same_dtype(None, [(3,)], None, None, TensorOfShape(1, dtype=torch.bool), 0))
|
||||
def aten〇masked_fill_〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], mask_rank_dtype: Tuple[int, int], value: Union[int, float]) -> int:
|
||||
def aten〇masked_fill_〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], mask_rank_dtype: Tuple[int, int], value: Union[int, float, complex]) -> int:
|
||||
self_rank, self_dtype = self_rank_dtype
|
||||
return self_dtype
|
||||
|
||||
|
@ -1766,7 +1766,7 @@ def aten〇scatter〇src〡dtype(self_rank_dtype: Tuple[int, int], dim: int, ind
|
|||
|
||||
@check_dtype_function(
|
||||
[Invocation(TensorOfShape(3, dtype=dtype), 0, TensorOfShape(3, dtype=torch.int64), 1.0) for dtype in _SORTED_TORCH_TYPES])
|
||||
def aten〇scatter〇value〡dtype(self_rank_dtype: Tuple[int, int], dim: int, index_rank_dtype: Tuple[int, int], value: Union[int, float]) -> int:
|
||||
def aten〇scatter〇value〡dtype(self_rank_dtype: Tuple[int, int], dim: int, index_rank_dtype: Tuple[int, int], value: Union[int, float, complex]) -> int:
|
||||
self_rank, self_dtype = self_rank_dtype
|
||||
return self_dtype
|
||||
|
||||
|
@ -1820,7 +1820,7 @@ def aten〇tanh_backward〡dtype(grad_output_rank_dtype: Tuple[int, int], output
|
|||
return promoted_dtype
|
||||
|
||||
@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, threshold=0, value=0))
|
||||
def aten〇threshold〡dtype(self_rank_dtype: Tuple[int, int], threshold: Union[int, float], value: Union[int, float]) -> int:
|
||||
def aten〇threshold〡dtype(self_rank_dtype: Tuple[int, int], threshold: Union[int, float, complex], value: Union[int, float, complex]) -> int:
|
||||
self_rank, self_dtype = self_rank_dtype
|
||||
return self_dtype
|
||||
|
||||
|
@ -1890,7 +1890,7 @@ def aten〇zero_〡dtype(self_rank_dtype: Tuple[int, int]) -> int:
|
|||
return self_dtype
|
||||
|
||||
@check_dtype_function([Invocation(-1), Invocation(-1.0)])
|
||||
def prim〇abs〇Scalar〡dtype(a: Union[int, float]) -> int:
|
||||
def prim〇abs〇Scalar〡dtype(a: Union[int, float, complex]) -> int:
|
||||
return get_dtype_of_scalar(a)
|
||||
|
||||
@check_dtype_function(_check_tensors_with_the_same_dtype(
|
||||
|
@ -1931,7 +1931,7 @@ def aten〇any〡dtype(self_rank_dtype: Tuple[int, int]) -> int:
|
|||
|
||||
@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, other=0.0) +
|
||||
_check_tensors_with_the_same_dtype(num_of_tensors=1, other=0))
|
||||
def aten〇eq〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], other: Union[int, float]) -> int:
|
||||
def aten〇eq〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], other: Union[int, float, complex]) -> int:
|
||||
return torch.bool
|
||||
|
||||
@check_dtype_function(_check_two_tensor_op())
|
||||
|
@ -1941,13 +1941,13 @@ def aten〇eq〇Tensor〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtyp
|
|||
@check_dtype_function(
|
||||
_check_tensors_with_the_same_dtype(num_of_tensors=1, other=0.0) +
|
||||
_check_tensors_with_the_same_dtype(num_of_tensors=1, other=0))
|
||||
def aten〇ge〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], other: Union[int, float]) -> int:
|
||||
def aten〇ge〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], other: Union[int, float, complex]) -> int:
|
||||
return torch.bool
|
||||
|
||||
@check_dtype_function(
|
||||
_check_tensors_with_the_same_dtype(num_of_tensors=1, other=0.0) +
|
||||
_check_tensors_with_the_same_dtype(num_of_tensors=1, other=0))
|
||||
def aten〇gt〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], other: Union[int, float]) -> int:
|
||||
def aten〇gt〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], other: Union[int, float, complex]) -> int:
|
||||
return torch.bool
|
||||
|
||||
@check_dtype_function(_check_two_tensor_op())
|
||||
|
@ -1961,7 +1961,7 @@ def aten〇ge〇Tensor〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtyp
|
|||
@check_dtype_function(
|
||||
_check_tensors_with_the_same_dtype(num_of_tensors=1, other=0.0) +
|
||||
_check_tensors_with_the_same_dtype(num_of_tensors=1, other=0))
|
||||
def aten〇le〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], other: Union[int, float]) -> int:
|
||||
def aten〇le〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], other: Union[int, float, complex]) -> int:
|
||||
return torch.bool
|
||||
|
||||
@check_dtype_function(_check_two_tensor_op())
|
||||
|
@ -1988,7 +1988,7 @@ def aten〇logical_xor〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtyp
|
|||
@check_dtype_function(
|
||||
_check_tensors_with_the_same_dtype(num_of_tensors=1, other=0.0) +
|
||||
_check_tensors_with_the_same_dtype(num_of_tensors=1, other=0))
|
||||
def aten〇lt〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], other: Union[int, float]) -> int:
|
||||
def aten〇lt〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], other: Union[int, float, complex]) -> int:
|
||||
return torch.bool
|
||||
|
||||
@check_dtype_function(_check_two_tensor_op())
|
||||
|
@ -2010,7 +2010,7 @@ def aten〇ne〇Tensor〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtyp
|
|||
@check_dtype_function(
|
||||
_check_tensors_with_the_same_dtype(num_of_tensors=1, other=0.0) +
|
||||
_check_tensors_with_the_same_dtype(num_of_tensors=1, other=0))
|
||||
def aten〇ne〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], other: Union[int, float]) -> int:
|
||||
def aten〇ne〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], other: Union[int, float, complex]) -> int:
|
||||
return torch.bool
|
||||
|
||||
@check_dtype_function([
|
||||
|
@ -2019,7 +2019,7 @@ def aten〇ne〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], other: Union[in
|
|||
Invocation(0, 0.0), # int, float
|
||||
Invocation(0, 0), # int, int
|
||||
])
|
||||
def aten〇add〡dtype(a: Union[int, float], b: Union[int, float]) -> int:
|
||||
def aten〇add〡dtype(a: Union[int, float, complex], b: Union[int, float, complex]) -> int:
|
||||
ranks: List[Optional[int]] = [None, None]
|
||||
dtypes = [get_dtype_of_scalar(a), get_dtype_of_scalar(b)]
|
||||
return promote_dtypes(ranks, dtypes)
|
||||
|
@ -2044,7 +2044,7 @@ def aten〇fft_fft〡dtype(self_rank_dtype: Tuple[int, int], n: Optional[int] =
|
|||
@check_dtype_function(
|
||||
_check_tensors_with_the_same_dtype(num_of_tensors=1, other=0.0) +
|
||||
_check_tensors_with_the_same_dtype(num_of_tensors=1, other=0))
|
||||
def aten〇rsub〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], other: Union[int, float], alpha: Union[int, float] = 1) -> int:
|
||||
def aten〇rsub〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], other: Union[int, float, complex], alpha: Union[int, float, complex] = 1) -> int:
|
||||
self_rank, self_dtype = self_rank_dtype
|
||||
return promote_dtypes([self_rank, None], [self_dtype, get_dtype_of_scalar(other)])
|
||||
|
||||
|
@ -2057,7 +2057,7 @@ def aten〇__and__〇Tensor〡dtype(self_rank_dtype: Tuple[int, int], other_rank
|
|||
return promote_dtypes(ranks, dtypes)
|
||||
|
||||
@check_dtype_function(_check_two_tensor_op())
|
||||
def aten〇add〇Tensor〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int], alpha: Union[int, float] = 1) -> int:
|
||||
def aten〇add〇Tensor〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int], alpha: Union[int, float, complex] = 1) -> int:
|
||||
other_rank, other_dtype = other_rank_dtype
|
||||
self_rank, self_dtype = self_rank_dtype
|
||||
ranks: List[Optional[int]] = [self_rank, other_rank]
|
||||
|
@ -2238,7 +2238,7 @@ def aten〇mv〡dtype(self_rank_dtype: Tuple[int, int], vec_rank_dtype: Tuple[in
|
|||
return promote_dtypes(ranks, dtypes)
|
||||
|
||||
@check_dtype_function(_check_two_tensor_op())
|
||||
def aten〇sub〇Tensor〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int], alpha: Union[int, float] = 1) -> int:
|
||||
def aten〇sub〇Tensor〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int], alpha: Union[int, float, complex] = 1) -> int:
|
||||
other_rank, other_dtype = other_rank_dtype
|
||||
self_rank, self_dtype = self_rank_dtype
|
||||
ranks: List[Optional[int]] = [self_rank, other_rank]
|
||||
|
@ -2249,7 +2249,7 @@ def aten〇sub〇Tensor〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dty
|
|||
# https://github.com/pytorch/pytorch/issues/100921
|
||||
# TODO: This should be fixed by switching to FakeTensor instead of Meta tensor
|
||||
@check_dtype_function(_check_two_tensor_op(tensor_device="cpu", input_error_types={torch.complex64, torch.complex128}, output_error_types={torch.bool}, threshold=0))
|
||||
def aten〇threshold_backward〡dtype(grad_output_rank_dtype: Tuple[int, int], self_rank_dtype: Tuple[int, int], threshold: Union[int, float]) -> int:
|
||||
def aten〇threshold_backward〡dtype(grad_output_rank_dtype: Tuple[int, int], self_rank_dtype: Tuple[int, int], threshold: Union[int, float, complex]) -> int:
|
||||
self_rank, self_dtype = self_rank_dtype
|
||||
grad_output_rank, grad_output_dtype = grad_output_rank_dtype
|
||||
assert not is_complex_dtype(grad_output_dtype), "`grad_output` cannot be complex"
|
||||
|
@ -2433,7 +2433,7 @@ def aten〇bincount〡dtype(self_rank_dtype: Tuple[int, int], weights_rank_dtype
|
|||
Invocation(TensorOfShape(3, 3, dtype=torch.int32),
|
||||
TensorOfShape(3, 4, dtype=torch.float32),
|
||||
TensorOfShape(4, 3, dtype=torch.float32))])
|
||||
def aten〇addmm〡dtype(self_rank_dtype: Tuple[int, int], mat1_rank_dtype: Tuple[int, int], mat2_rank_dtype: Tuple[int, int], beta: Union[int, float] = 1, alpha: Union[int, float] = 1) -> int:
|
||||
def aten〇addmm〡dtype(self_rank_dtype: Tuple[int, int], mat1_rank_dtype: Tuple[int, int], mat2_rank_dtype: Tuple[int, int], beta: Union[int, float, complex] = 1, alpha: Union[int, float, complex] = 1) -> int:
|
||||
self_rank, self_dtype = self_rank_dtype
|
||||
mat1_rank, mat1_dtype = mat1_rank_dtype
|
||||
mat2_rank, mat2_dtype = mat2_rank_dtype
|
||||
|
@ -2477,7 +2477,7 @@ def aten〇lerp〇Tensor〡dtype(self_rank_dtype: Tuple[int, int], end_rank_dtyp
|
|||
Invocation(TensorOfShape(3, 3, dtype=torch.int32),
|
||||
TensorOfShape(3, 3, dtype=torch.float32),
|
||||
TensorOfShape(3, 3, dtype=torch.float32))])
|
||||
def aten〇addcmul〡dtype(self_rank_dtype: Tuple[int, int], tensor1_rank_dtype: Tuple[int, int], tensor2_rank_dtype: Tuple[int, int], value: Union[int, float] = 1) -> int:
|
||||
def aten〇addcmul〡dtype(self_rank_dtype: Tuple[int, int], tensor1_rank_dtype: Tuple[int, int], tensor2_rank_dtype: Tuple[int, int], value: Union[int, float, complex] = 1) -> int:
|
||||
self_rank, self_dtype = self_rank_dtype
|
||||
tensor1_rank, tensor1_dtype = tensor1_rank_dtype
|
||||
tensor2_rank, tensor2_dtype = tensor2_rank_dtype
|
||||
|
@ -2503,7 +2503,7 @@ def aten〇addcmul〡dtype(self_rank_dtype: Tuple[int, int], tensor1_rank_dtype:
|
|||
Invocation(TensorOfShape(3, 3, dtype=torch.int32),
|
||||
TensorOfShape(3, 3, dtype=torch.float32),
|
||||
TensorOfShape(3, 3, dtype=torch.float32))])
|
||||
def aten〇addcdiv〡dtype(self_rank_dtype: Tuple[int, int], tensor1_rank_dtype: Tuple[int, int], tensor2_rank_dtype: Tuple[int, int], value: Union[int, float] = 1) -> int:
|
||||
def aten〇addcdiv〡dtype(self_rank_dtype: Tuple[int, int], tensor1_rank_dtype: Tuple[int, int], tensor2_rank_dtype: Tuple[int, int], value: Union[int, float, complex] = 1) -> int:
|
||||
self_rank, self_dtype = self_rank_dtype
|
||||
tensor1_rank, tensor1_dtype = tensor1_rank_dtype
|
||||
tensor2_rank, tensor2_dtype = tensor2_rank_dtype
|
||||
|
@ -2517,7 +2517,7 @@ def aten〇addcdiv〡dtype(self_rank_dtype: Tuple[int, int], tensor1_rank_dtype:
|
|||
|
||||
@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, other=1) +
|
||||
_check_tensors_with_the_same_dtype(num_of_tensors=1, other=1.0))
|
||||
def aten〇add〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], other: Union[int, float], alpha: Union[int, float] = 1) -> int:
|
||||
def aten〇add〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], other: Union[int, float, complex], alpha: Union[int, float, complex] = 1) -> int:
|
||||
self_rank, self_dtype = self_rank_dtype
|
||||
ranks: List[Optional[int]] = [self_rank, None]
|
||||
dtypes = [self_dtype, get_dtype_of_scalar(other)]
|
||||
|
@ -2526,7 +2526,7 @@ def aten〇add〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], other: Union[i
|
|||
@check_dtype_function(
|
||||
_check_tensors_with_the_same_dtype(num_of_tensors=1, other=1) +
|
||||
_check_tensors_with_the_same_dtype(num_of_tensors=1, other=1.0))
|
||||
def aten〇sub〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], other: Union[int, float], alpha: Union[int, float] = 1) -> int:
|
||||
def aten〇sub〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], other: Union[int, float, complex], alpha: Union[int, float, complex] = 1) -> int:
|
||||
self_rank, self_dtype = self_rank_dtype
|
||||
ranks: List[Optional[int]] = [self_rank, None]
|
||||
dtypes = [self_dtype, get_dtype_of_scalar(other)]
|
||||
|
@ -2534,7 +2534,7 @@ def aten〇sub〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], other: Union[i
|
|||
|
||||
@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, other=1) +
|
||||
_check_tensors_with_the_same_dtype(num_of_tensors=1, other=1.0))
|
||||
def aten〇mul〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], other: Union[int, float]) -> int:
|
||||
def aten〇mul〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], other: Union[int, float, complex]) -> int:
|
||||
self_rank, self_dtype = self_rank_dtype
|
||||
ranks: List[Optional[int]] = [self_rank, None]
|
||||
dtypes = [self_dtype, get_dtype_of_scalar(other)]
|
||||
|
@ -2542,7 +2542,7 @@ def aten〇mul〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], other: Union[i
|
|||
|
||||
@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, other=1) +
|
||||
_check_tensors_with_the_same_dtype(num_of_tensors=1, other=1.0))
|
||||
def aten〇div〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], other: Union[int, float]) -> int:
|
||||
def aten〇div〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], other: Union[int, float, complex]) -> int:
|
||||
self_rank, self_dtype = self_rank_dtype
|
||||
ranks: List[Optional[int]] = [self_rank, None]
|
||||
dtypes = [self_dtype, get_dtype_of_scalar(other)]
|
||||
|
@ -2554,7 +2554,7 @@ def aten〇div〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], other: Union[i
|
|||
@check_dtype_function(
|
||||
_check_tensors_with_the_same_dtype(num_of_tensors=1, other=1) +
|
||||
_check_tensors_with_the_same_dtype(num_of_tensors=1, other=1.0))
|
||||
def aten〇fmod〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], other: Union[int, float]) -> int:
|
||||
def aten〇fmod〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], other: Union[int, float, complex]) -> int:
|
||||
self_rank, self_dtype = self_rank_dtype
|
||||
ranks: List[Optional[int]] = [self_rank, None]
|
||||
dtypes = [self_dtype, get_dtype_of_scalar(other)]
|
||||
|
@ -2563,7 +2563,7 @@ def aten〇fmod〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], other: Union[
|
|||
@check_dtype_function(
|
||||
_check_tensors_with_the_same_dtype(num_of_tensors=1, error_types={torch.complex64, torch.complex128}, other=1) +
|
||||
_check_tensors_with_the_same_dtype(num_of_tensors=1, error_types={torch.complex64, torch.complex128}, other=1.0))
|
||||
def aten〇floor_divide〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], other: Union[int, float]) -> int:
|
||||
def aten〇floor_divide〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], other: Union[int, float, complex]) -> int:
|
||||
self_rank, self_dtype = self_rank_dtype
|
||||
assert not is_complex_dtype(self_dtype)
|
||||
ranks: List[Optional[int]] = [self_rank, None]
|
||||
|
@ -2572,7 +2572,7 @@ def aten〇floor_divide〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], other
|
|||
|
||||
@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, exponent=1) +
|
||||
_check_tensors_with_the_same_dtype(num_of_tensors=1, exponent=1.0))
|
||||
def aten〇pow〇Tensor_Scalar〡dtype(self_rank_dtype: Tuple[int, int], exponent: Union[int, float]) -> int:
|
||||
def aten〇pow〇Tensor_Scalar〡dtype(self_rank_dtype: Tuple[int, int], exponent: Union[int, float, complex]) -> int:
|
||||
self_rank, self_dtype = self_rank_dtype
|
||||
ranks: List[Optional[int]] = [self_rank, None]
|
||||
dtypes = [self_dtype, get_dtype_of_scalar(exponent)]
|
||||
|
@ -2581,7 +2581,7 @@ def aten〇pow〇Tensor_Scalar〡dtype(self_rank_dtype: Tuple[int, int], exponen
|
|||
@check_dtype_function(
|
||||
_check_tensors_with_the_same_dtype(num_of_tensors=1, error_types={torch.bool}, negative_slope=1) +
|
||||
_check_tensors_with_the_same_dtype(num_of_tensors=1, error_types={torch.bool, torch.int8, torch.uint8, torch.int16, torch.int32, torch.int64}, negative_slope=1.0))
|
||||
def aten〇leaky_relu〡dtype(self_rank_dtype: Tuple[int, int], negative_slope: Union[int, float] = 0.01) -> int:
|
||||
def aten〇leaky_relu〡dtype(self_rank_dtype: Tuple[int, int], negative_slope: Union[int, float, complex] = 0.01) -> int:
|
||||
self_rank, self_dtype = self_rank_dtype
|
||||
assert self_dtype != torch.bool
|
||||
ranks: List[Optional[int]] = [self_rank, None]
|
||||
|
@ -2594,7 +2594,7 @@ def aten〇leaky_relu〡dtype(self_rank_dtype: Tuple[int, int], negative_slope:
|
|||
@check_dtype_function(
|
||||
_check_tensors_with_the_same_dtype(num_of_tensors=1, other=1) +
|
||||
_check_tensors_with_the_same_dtype(num_of_tensors=1, other=1.0))
|
||||
def aten〇remainder〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], other: Union[int, float]) -> int:
|
||||
def aten〇remainder〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], other: Union[int, float, complex]) -> int:
|
||||
self_rank, self_dtype = self_rank_dtype
|
||||
ranks: List[Optional[int]] = [self_rank, None]
|
||||
dtypes = [self_dtype, get_dtype_of_scalar(other)]
|
||||
|
@ -2611,7 +2611,7 @@ def aten〇remainder〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], other: U
|
|||
TensorOfShape(1, 1, 1, dtype=torch.float64, device="cpu"), TensorOfShape(1, 1, 1, dtype=torch.float16, device="cpu"), TensorOfShape(1, 1, 1, dtype=torch.int64, device="cpu")),
|
||||
ErrorInvocation(
|
||||
TensorOfShape(1, 1, 1, dtype=torch.float64, device="cpu"), TensorOfShape(1, 1, 1, dtype=torch.bfloat16, device="cpu"), TensorOfShape(1, 1, 1, dtype=torch.float16, device="cpu"))])
|
||||
def aten〇baddbmm〡dtype(self_rank_dtype: Tuple[int, int], batch1_rank_dtype: Tuple[int, int], batch2_rank_dtype: Tuple[int, int], beta: Union[int, float] = 1, alpha: Union[int, float] = 1) -> int:
|
||||
def aten〇baddbmm〡dtype(self_rank_dtype: Tuple[int, int], batch1_rank_dtype: Tuple[int, int], batch2_rank_dtype: Tuple[int, int], beta: Union[int, float, complex] = 1, alpha: Union[int, float, complex] = 1) -> int:
|
||||
batch1_rank, batch1_dtype = batch1_rank_dtype
|
||||
batch2_rank, batch2_dtype = batch2_rank_dtype
|
||||
assert batch1_dtype not in [torch.bool, torch.float16]
|
||||
|
@ -2637,7 +2637,7 @@ def aten〇where〇self〡dtype(condition_rank_dtype: Tuple[int, int], self_rank
|
|||
Invocation(NonZeroDTensorWithDtype(torch.bool), 0, 0.0),
|
||||
Invocation(NonZeroDTensorWithDtype(torch.bool), 0.0, 0),
|
||||
Invocation(NonZeroDTensorWithDtype(torch.bool), 0.0, 0.0)])
|
||||
def aten〇where〇Scalar〡dtype(condition_rank_dtype: Tuple[int, int], self: Union[int, float], other: Union[int, float]) -> int:
|
||||
def aten〇where〇Scalar〡dtype(condition_rank_dtype: Tuple[int, int], self: Union[int, float, complex], other: Union[int, float, complex]) -> int:
|
||||
if is_integer_dtype(get_dtype_of_scalar(self)) and is_integer_dtype(get_dtype_of_scalar(other)):
|
||||
return torch.int64
|
||||
return torch.float32
|
||||
|
@ -2646,7 +2646,7 @@ def aten〇where〇Scalar〡dtype(condition_rank_dtype: Tuple[int, int], self: U
|
|||
Invocation(NonZeroDTensorWithDtype(torch.bool), NonZeroDTensorWithDtype(torch.int64), 0.0),
|
||||
Invocation(NonZeroDTensorWithDtype(torch.bool), NonZeroDTensorWithDtype(torch.float16), 0),
|
||||
Invocation(NonZeroDTensorWithDtype(torch.bool), NonZeroDTensorWithDtype(torch.float64), 0.0)])
|
||||
def aten〇where〇ScalarOther〡dtype(condition_rank_dtype: Tuple[int, int], self_rank_dtype: Tuple[int, int], other: Union[int, float]) -> int:
|
||||
def aten〇where〇ScalarOther〡dtype(condition_rank_dtype: Tuple[int, int], self_rank_dtype: Tuple[int, int], other: Union[int, float, complex]) -> int:
|
||||
self_rank, self_dtype = self_rank_dtype
|
||||
ranks: List[Optional[int]] = [self_rank, None]
|
||||
dtypes = [self_dtype, get_dtype_of_scalar(other)]
|
||||
|
@ -2656,7 +2656,7 @@ def aten〇where〇ScalarOther〡dtype(condition_rank_dtype: Tuple[int, int], se
|
|||
Invocation(NonZeroDTensorWithDtype(torch.bool), 0.0, NonZeroDTensorWithDtype(torch.int64)),
|
||||
Invocation(NonZeroDTensorWithDtype(torch.bool), 0, NonZeroDTensorWithDtype(torch.float16)),
|
||||
Invocation(NonZeroDTensorWithDtype(torch.bool), 0.0, NonZeroDTensorWithDtype(torch.float64))])
|
||||
def aten〇where〇ScalarSelf〡dtype(condition_rank_dtype: Tuple[int, int], self: Union[int, float], other_rank_dtype: Tuple[int, int]) -> int:
|
||||
def aten〇where〇ScalarSelf〡dtype(condition_rank_dtype: Tuple[int, int], self: Union[int, float, complex], other_rank_dtype: Tuple[int, int]) -> int:
|
||||
other_rank, other_dtype = other_rank_dtype
|
||||
ranks: List[Optional[int]] = [None, other_rank]
|
||||
dtypes = [get_dtype_of_scalar(self), other_dtype]
|
||||
|
@ -2755,7 +2755,7 @@ def aten〇native_batch_norm〡dtype(input_rank_dtype: Tuple[int, int], weight_r
|
|||
ErrorInvocation(end=0, dtype=torch.complex64), # Dtype specified
|
||||
Invocation(end=0, dtype=torch.float16), # Dtype specified
|
||||
Invocation(end=0, dtype=torch.int16)]) # Dtype specified
|
||||
def aten〇arange〡dtype(end: Union[int, float], dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None) -> int:
|
||||
def aten〇arange〡dtype(end: Union[int, float, complex], dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None) -> int:
|
||||
if dtype is not None:
|
||||
assert not is_complex_dtype(dtype)
|
||||
return dtype
|
||||
|
@ -2769,7 +2769,7 @@ def aten〇arange〡dtype(end: Union[int, float], dtype: Optional[int] = None, l
|
|||
ErrorInvocation(start=0, end=10, dtype=torch.complex64), # Dtype specified
|
||||
Invocation(start=0, end=10, dtype=torch.float16), # Dtype specified
|
||||
Invocation(start=0, end=10, dtype=torch.int16)]) # Dtype specified
|
||||
def aten〇arange〇start〡dtype(start: Union[int, float], end: Union[int, float], dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None) -> int:
|
||||
def aten〇arange〇start〡dtype(start: Union[int, float, complex], end: Union[int, float, complex], dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None) -> int:
|
||||
if dtype is not None:
|
||||
assert not is_complex_dtype(dtype)
|
||||
return dtype
|
||||
|
@ -2785,7 +2785,7 @@ def aten〇arange〇start〡dtype(start: Union[int, float], end: Union[int, floa
|
|||
ErrorInvocation(start=0, end=10, step=1, dtype=torch.complex64), # Dtype specified
|
||||
Invocation(start=0, end=10, step=1, dtype=torch.float16), # Dtype specified
|
||||
Invocation(start=0, end=10, step=1, dtype=torch.int16)]) # Dtype specified
|
||||
def aten〇arange〇start_step〡dtype(start: Union[int, float], end: Union[int, float], step: Union[int, float] = 1, dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None) -> int:
|
||||
def aten〇arange〇start_step〡dtype(start: Union[int, float, complex], end: Union[int, float, complex], step: Union[int, float, complex] = 1, dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None) -> int:
|
||||
if dtype is not None:
|
||||
assert not is_complex_dtype(dtype)
|
||||
return dtype
|
||||
|
@ -2876,7 +2876,7 @@ def aten〇std〇dim〡dtype(self_rank_dtype: Tuple[int, int], dim: Optional[Lis
|
|||
return aten〇std〡dtype(self_rank_dtype)
|
||||
|
||||
@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1))
|
||||
def aten〇std〇correction〡dtype(self_rank_dtype: Tuple[int, int], dim: Optional[List[int]] = None, correction: Optional[Union[int, float]] = None, keepdim: bool = False) -> int:
|
||||
def aten〇std〇correction〡dtype(self_rank_dtype: Tuple[int, int], dim: Optional[List[int]] = None, correction: Optional[Union[int, float, complex]] = None, keepdim: bool = False) -> int:
|
||||
return aten〇std〡dtype(self_rank_dtype)
|
||||
|
||||
@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1))
|
||||
|
@ -2888,7 +2888,7 @@ def aten〇var〇dim〡dtype(self_rank_dtype: Tuple[int, int], dim: Optional[Lis
|
|||
return aten〇std〡dtype(self_rank_dtype)
|
||||
|
||||
@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1))
|
||||
def aten〇var〇correction〡dtype(self_rank_dtype: Tuple[int, int], dim: Optional[List[int]] = None, correction: Optional[Union[int, float]] = None, keepdim: bool = False) -> int:
|
||||
def aten〇var〇correction〡dtype(self_rank_dtype: Tuple[int, int], dim: Optional[List[int]] = None, correction: Optional[Union[int, float, complex]] = None, keepdim: bool = False) -> int:
|
||||
return aten〇std〡dtype(self_rank_dtype)
|
||||
|
||||
@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, dims=[], correction=0.0))
|
||||
|
@ -2906,7 +2906,7 @@ def prims〇var〡dtype(inp_rank_dtype: Tuple[int, int], dims: Optional[List[int
|
|||
num_of_tensors=1,
|
||||
error_types={torch.bool, torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64, torch.bfloat16, torch.float16, torch.float32, torch.float64}, dtype=torch.complex128) +
|
||||
[ErrorInvocation(NonZeroDTensorWithDtype(torch.float32), dtype=torch.int32)])
|
||||
def aten〇linalg_vector_norm〡dtype(self_rank_dtype: Tuple[int, int], ord: Union[int, float] = 2, dim: Optional[List[int]] = None, keepdim: bool = False, dtype: Optional[int] = None) -> int:
|
||||
def aten〇linalg_vector_norm〡dtype(self_rank_dtype: Tuple[int, int], ord: Union[int, float, complex] = 2, dim: Optional[List[int]] = None, keepdim: bool = False, dtype: Optional[int] = None) -> int:
|
||||
self_rank, self_dtype = self_rank_dtype
|
||||
assert not is_integer_dtype(self_dtype)
|
||||
if dtype is not None:
|
||||
|
@ -2971,7 +2971,7 @@ def aten〇empty〇memory_format〡dtype(size: List[int], dtype: Optional[int] =
|
|||
Invocation([1], 0.0, dtype=torch.int32),
|
||||
Invocation([1], 0.0, dtype=torch.float16),
|
||||
Invocation([1], 0.0, dtype=torch.complex64)])
|
||||
def aten〇full〡dtype(size: List[int], fill_value: Union[int, float], dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None) -> int:
|
||||
def aten〇full〡dtype(size: List[int], fill_value: Union[int, float, complex], dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None) -> int:
|
||||
if dtype is not None:
|
||||
return dtype
|
||||
fill_value_dtype = get_dtype_of_scalar(fill_value)
|
||||
|
@ -3009,7 +3009,7 @@ def aten〇empty_like〡dtype(self_rank_dtype: Tuple[int, int], dtype: Optional[
|
|||
_check_tensors_with_the_same_dtype(num_of_tensors=1, fill_value=0.0, dtype=torch.float16) +
|
||||
_check_tensors_with_the_same_dtype(num_of_tensors=1, fill_value=0.0, dtype=torch.int32) +
|
||||
_check_tensors_with_the_same_dtype(num_of_tensors=1, fill_value=0.0, dtype=torch.complex64))
|
||||
def aten〇full_like〡dtype(self_rank_dtype: Tuple[int, int], fill_value: Union[int, float], dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None, memory_format: Optional[int] = None) -> int:
|
||||
def aten〇full_like〡dtype(self_rank_dtype: Tuple[int, int], fill_value: Union[int, float, complex], dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None, memory_format: Optional[int] = None) -> int:
|
||||
self_rank, self_dtype = self_rank_dtype
|
||||
return self_dtype if dtype is None else dtype
|
||||
|
||||
|
@ -3143,7 +3143,7 @@ def aten〇randn〇generator〡dtype(size: List[int], generator: Any, dtype: Opt
|
|||
return dtype
|
||||
|
||||
@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, error_types=all_integer_dtypes()))
|
||||
def aten〇var_mean〇correction〡dtype(self_rank_dtype: Tuple[int, int], dim: Optional[List[int]] = None, correction: Optional[Union[int, float]] = None, keepdim: bool = False) -> Tuple[int, int]:
|
||||
def aten〇var_mean〇correction〡dtype(self_rank_dtype: Tuple[int, int], dim: Optional[List[int]] = None, correction: Optional[Union[int, float, complex]] = None, keepdim: bool = False) -> Tuple[int, int]:
|
||||
self_rank, self_dtype = self_rank_dtype
|
||||
assert not is_integer_dtype(self_dtype)
|
||||
if self_dtype == torch.complex64:
|
||||
|
@ -3220,7 +3220,7 @@ def aten〇ScalarImplicit〡dtype(a_rank_dtype: Tuple[int, int]) -> int:
|
|||
assert False, "Unexpected dtype!"
|
||||
|
||||
@check_dtype_function([Invocation(0), Invocation(0.0)])
|
||||
def prim〇NumToTensor〇Scalar〡dtype(a: Union[int, float]) -> int:
|
||||
def prim〇NumToTensor〇Scalar〡dtype(a: Union[int, float, complex]) -> int:
|
||||
return get_dtype_of_scalar(a)
|
||||
|
||||
@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, dim=0) +
|
||||
|
|
|
@ -63,7 +63,7 @@ def get_priority_of_dtype(dtype: int) -> int:
|
|||
return 11
|
||||
assert False, "Cannot determine priority of dtype"
|
||||
|
||||
def get_dtype_of_scalar(scalar: Union[int, float]) -> int:
|
||||
def get_dtype_of_scalar(scalar: Union[int, float, complex]) -> int:
|
||||
# This is hacky. `NumToTensor` is the only PyTorch op for scalars
|
||||
# that when `jit.script`ed converts a float scalar to a tensor
|
||||
# with dtype that corresponds to Python's `float`.
|
||||
|
|
|
@ -49,7 +49,7 @@ def _get_default_value(arg: "SIG_ATTR_TYPE") -> str:
|
|||
|
||||
def _pytype_to_fn_pytype_common(pytype: str) -> str:
|
||||
if "number" in pytype:
|
||||
return pytype.replace("number", "Union[int, float]")
|
||||
return pytype.replace("number", "Union[int, float, complex]")
|
||||
# `torch.device` is lowercase.
|
||||
if pytype == "Device":
|
||||
return "device"
|
||||
|
|
|
@ -72,3 +72,18 @@ func.func @turn_tensors_into_rank_and_dtype_args(%arg0: !torch.vtensor, %arg1: !
|
|||
%0 = torch.aten.floor_divide %arg0, %arg1 : !torch.vtensor, !torch.vtensor -> !torch.vtensor
|
||||
return %0 : !torch.vtensor
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func private @__torch_mlir_dtype_fn.aten.arange(
|
||||
|
||||
// CHECK-LABEL: func.func @derefine_int_to_number() -> !torch.vtensor {
|
||||
// CHECK: %[[INT1:.*]] = torch.constant.int 1
|
||||
// CHECK: %[[NUMBER:.*]] = torch.derefine %[[INT1]] : !torch.int to !torch.number
|
||||
// CHECK: {{.*}} = func.call @__torch_mlir_dtype_fn.aten.arange(%[[NUMBER]], {{.*}}) : (!torch.number, {{.*}}) -> !torch.int
|
||||
func.func @derefine_int_to_number() -> !torch.vtensor {
|
||||
%int1 = torch.constant.int 1
|
||||
%none = torch.constant.none
|
||||
%0 = torch.aten.arange %int1, %none, %none, %none, %none : !torch.int, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor
|
||||
return %0 : !torch.vtensor
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue