From de02b56e17739a33cfaa5e1c32739f83ae20fdb4 Mon Sep 17 00:00:00 2001
From: Ramiro Leal-Cavazos
Date: Fri, 12 May 2023 13:40:45 -0700
Subject: [PATCH] Replace RefineTypes with dtype functions (#2105)

This commit adds dtype functions for all the torch ops that did not
previously have one and removes the pass `RefineTypes`, since the
abstract interpretation library now takes care of all the dtype
propagation.

All dtype functions added are tested except for
- `aten.embedding`
- `aten._embedding_bag`
- `aten.embedding_bag`

These functions need a change to the testing framework to allow
specifying the actual data inside the tensor used for testing. I will
fix this in a follow-up patch.

Co-authored-by: Jiahao Li
---
 docs/architecture.md                          |    5 +-
 .../Dialect/Torch/Transforms/Passes.h         |    2 -
 .../Dialect/Torch/Transforms/Passes.td        |    9 -
 .../Transforms/AbstractInterpLibrary.cpp      | 2779 ++++++++++++++++-
 lib/Dialect/Torch/Transforms/CMakeLists.txt   |    1 -
 .../Transforms/LowerToBackendContract.cpp     |    3 +-
 lib/Dialect/Torch/Transforms/Passes.cpp       |   12 +-
 lib/Dialect/Torch/Transforms/RefineTypes.cpp  | 1770 -----------
 .../ReifyAbstractInterpCalculationsUtils.cpp  |   17 +-
 ...implifyAbstractInterpCalculationsUtils.cpp |  252 ++
 .../SimplifyAbstractInterpCalculationsUtils.h |    7 +
 .../Transforms/SimplifyDtypeCalculations.cpp  |    7 +
 .../Transforms/SimplifyShapeCalculations.cpp  |  242 +-
 .../build_tools/abstract_interp_lib_gen.py    | 2241 ++++++++++++-
 .../jit_ir/build_tools/library_generator.py   |   49 +
 .../jit_ir/build_tools/testing_framework.py   |   37 +-
 test/Dialect/Torch/refine-types-branch.mlir   |  153 -
 test/Dialect/Torch/refine-types-ops.mlir      |  364 ---
 test/Dialect/Torch/refine-types.mlir          |  238 --
 19 files changed, 5104 insertions(+), 3084 deletions(-)
 delete mode 100644 lib/Dialect/Torch/Transforms/RefineTypes.cpp
 delete mode 100644 test/Dialect/Torch/refine-types-branch.mlir
 delete mode 100644 test/Dialect/Torch/refine-types-ops.mlir
 delete mode 100644 test/Dialect/Torch/refine-types.mlir

diff --git a/docs/architecture.md b/docs/architecture.md
index 043bb74ec..e503ba40d 100644
--- a/docs/architecture.md
+++ b/docs/architecture.md
@@ -242,8 +242,7 @@ The `torchscript-module-to-torch-backend-pipeline` contains the set of simplific
 1. LowerToBackendContract: This pass iteratively applies a simplification pipeline
    until the backend contract is reached. The simplification pipeline consists of:
     - Standard canonicalization.
-    - Shape refinement. See [shape_lib.md](https://github.com/llvm/torch-mlir/blob/main/docs/shape_lib.md) for detail
-    - DType refinement. See `RefineTypes`.
+    - Shape and dtype refinement. See [abstract_interp_lib.md](https://github.com/llvm/torch-mlir/blob/main/docs/abstract_interp_lib.md) for details.
     - Decomposing ops into more primitive ops. See `DecomposeComplexOps`.
 
 ### Layering of the PyTorch Dependency
@@ -414,8 +413,6 @@ DON'T use a unit test if your lowering pattern could be described as a trivial
 like your unit test is just rewriting `b.create<...>(...)` into `CHECK: ...`
 then it is probably not a useful unit test.
 
-DON'T add a unit test for trivial changes to RefineTypes.
-
 With the exceptions above, all changes should include appropriate unit tests,
 as is standard in the LLVM and MLIR community. This includes full coverage of
 all canonicalizations, pretty printing, passes, errors, and diagnostics.
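For orientation: the large additions to AbstractInterpLibrary.cpp below are C++
string literals holding MLIR that is generated from Python dtype functions in
abstract_interp_lib_gen.py. The following is a minimal sketch of the Python
source behind one generated entry (`__torch_mlir_dtype_fn.aten.tanh` and its
`_get_dtype_of_floating_point_op` helper, both visible later in this diff);
the import path and the 〇/〡 naming convention follow the build tools, but the
exact signatures are an assumption, not part of the patch.

    # Sketch of the Python behind the generated
    # "__torch_mlir_dtype_fn.aten.tanh" below; helper and module names are
    # taken from the generated calls in this diff.
    from typing import Tuple

    import torch

    from torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator import (
        is_complex_dtype, is_float_dtype)

    def _get_dtype_of_floating_point_op(input_dtype: int) -> int:
        # Float dtypes other than float32, and complex dtypes, pass through;
        # everything else (including integer and bool inputs) produces
        # float32. Under TorchScript, torch.dtype values are interchangeable
        # with int, hence the annotations.
        if (is_float_dtype(input_dtype) and input_dtype != torch.float32) \
                or is_complex_dtype(input_dtype):
            return input_dtype
        return torch.float32

    # "〇" encodes "." and "〡" separates the op name from the kind of
    # abstract interpretation function ("dtype"), per the build tools'
    # naming convention.
    def aten〇tanh〡dtype(self_rank_dtype: Tuple[int, int]) -> int:
        self_rank, self_dtype = self_rank_dtype
        return _get_dtype_of_floating_point_op(self_dtype)

Each dtype function receives a (rank, dtype) tuple for every tensor operand and
returns the torch dtype (as an int) of the result; the abstract interpretation
library applies these functions during dtype refinement in place of the deleted
RefineTypes pass.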
diff --git a/include/torch-mlir/Dialect/Torch/Transforms/Passes.h b/include/torch-mlir/Dialect/Torch/Transforms/Passes.h index 8e817374b..84efddcc9 100644 --- a/include/torch-mlir/Dialect/Torch/Transforms/Passes.h +++ b/include/torch-mlir/Dialect/Torch/Transforms/Passes.h @@ -92,8 +92,6 @@ void createTorchDtypeRefinementPipeline( std::unique_ptr> createAdjustCallingConventionsPass(); -std::unique_ptr> createRefineTypesPass(); - std::unique_ptr> createInlineGlobalSlotsPass(); std::unique_ptr> diff --git a/include/torch-mlir/Dialect/Torch/Transforms/Passes.td b/include/torch-mlir/Dialect/Torch/Transforms/Passes.td index 8369d1d3d..d9331322f 100644 --- a/include/torch-mlir/Dialect/Torch/Transforms/Passes.td +++ b/include/torch-mlir/Dialect/Torch/Transforms/Passes.td @@ -126,15 +126,6 @@ def AdjustCallingConventions }]; } -def RefineTypes : Pass<"torch-refine-types", "func::FuncOp"> { - let summary = "Refine types"; - let constructor = "mlir::torch::Torch::createRefineTypesPass()"; - let description = [{ - Refines types of the program. Currently, this means shapes and dtypes of - tensors/arrays. - }]; -} - def InlineGlobalSlots : Pass<"torch-inline-global-slots", "ModuleOp"> { let summary = "Inlines torch.global_slot ops."; let constructor = "mlir::torch::Torch::createInlineGlobalSlotsPass()"; diff --git a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp index c8271fc6c..8b29aef71 100644 --- a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp +++ b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp @@ -6111,6 +6111,9 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %3 = torch.prim.ListConstruct %1, %2 : (!torch.int, !torch.int) -> !torch.list\n" " return %3 : !torch.list\n" " }\n" +" func.func @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.testing_framework._convert_dtype_to_int(%arg0: !torch.int) -> !torch.int {\n" +" return %arg0 : !torch.int\n" +" }\n" " func.func @\"__torch_mlir_shape_fn.aten.triu\"(%arg0: !torch.list, %arg1: !torch.int) -> !torch.list {\n" " %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" " return %0 : !torch.list\n" @@ -6159,33 +6162,6 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" " return %0 : !torch.list\n" " }\n" -" func.func @\"__torch_mlir_dtype_fn.aten.expm1\"(%arg0: !torch.tuple) -> !torch.int {\n" -" %int6 = torch.constant.int 6\n" -" %int5 = torch.constant.int 5\n" -" %int15 = torch.constant.int 15\n" -" %true = torch.constant.bool true\n" -" %int7 = torch.constant.int 7\n" -" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" -" %1 = torch.aten.eq.int %0#1, %int7 : !torch.int, !torch.int -> !torch.bool\n" -" %2 = torch.prim.If %1 -> (!torch.bool) {\n" -" torch.prim.If.yield %true : !torch.bool\n" -" } else {\n" -" %5 = torch.aten.eq.int %0#1, %int15 : !torch.int, !torch.int -> !torch.bool\n" -" torch.prim.If.yield %5 : !torch.bool\n" -" }\n" -" %3 = torch.prim.If %2 -> (!torch.bool) {\n" -" torch.prim.If.yield %true : !torch.bool\n" -" } else {\n" -" %5 = torch.aten.eq.int %0#1, %int5 : !torch.int, !torch.int -> !torch.bool\n" -" torch.prim.If.yield %5 : !torch.bool\n" -" }\n" -" %4 = torch.prim.If %3 -> (!torch.int) {\n" -" torch.prim.If.yield %0#1 : !torch.int\n" -" } else {\n" -" torch.prim.If.yield %int6 : !torch.int\n" -" }\n" -" return %4 : !torch.int\n" -" }\n" " 
func.func @\"__torch_mlir_shape_fn.aten.sin\"(%arg0: !torch.list) -> !torch.list {\n" " %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" " return %0 : !torch.list\n" @@ -6417,24 +6393,6 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %0 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %arg1) : (!torch.list, !torch.list) -> !torch.list\n" " return %0 : !torch.list\n" " }\n" -" func.func @\"__torch_mlir_dtype_fn.aten.rsub.Scalar\"(%arg0: !torch.tuple, %arg1: !torch.union, %arg2: !torch.union) -> !torch.int {\n" -" %none = torch.constant.none\n" -" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" -" %1 = torch.prim.ListConstruct %0#0, %none : (!torch.int, !torch.none) -> !torch.list>\n" -" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg1) : (!torch.union) -> !torch.int\n" -" %3 = torch.prim.ListConstruct %0#1, %2 : (!torch.int, !torch.int) -> !torch.list\n" -" %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%1, %3) : (!torch.list>, !torch.list) -> !torch.int\n" -" return %4 : !torch.int\n" -" }\n" -" func.func @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%arg0: !torch.list>, %arg1: !torch.list) -> !torch.int {\n" -" %0 = torch.promote_dtypes %arg0, %arg1 : (!torch.list>, !torch.list) -> !torch.int\n" -" return %0 : !torch.int\n" -" }\n" -" func.func @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg0: !torch.union) -> !torch.int {\n" -" %0 = torch.prim.NumToTensor.Scalar %arg0 : !torch.union -> !torch.tensor\n" -" %1 = torch.prim.dtype %0 : !torch.tensor -> !torch.int\n" -" return %1 : !torch.int\n" -" }\n" " func.func @\"__torch_mlir_shape_fn.aten.leaky_relu\"(%arg0: !torch.list, %arg1: !torch.float) -> !torch.list {\n" " %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" " return %0 : !torch.list\n" @@ -7067,14 +7025,6 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %0 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %arg1) : (!torch.list, !torch.list) -> !torch.list\n" " return %0 : !torch.list\n" " }\n" -" func.func @\"__torch_mlir_dtype_fn.aten.floor_divide\"(%arg0: !torch.tuple, %arg1: !torch.tuple) -> !torch.int {\n" -" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" -" %1:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" -" %2 = torch.prim.ListConstruct %0#0, %1#0 : (!torch.int, !torch.int) -> !torch.list>\n" -" %3 = torch.prim.ListConstruct %0#1, %1#1 : (!torch.int, !torch.int) -> !torch.list\n" -" %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list>, !torch.list) -> !torch.int\n" -" return %4 : !torch.int\n" -" }\n" " func.func @\"__torch_mlir_shape_fn.aten.atan2\"(%arg0: !torch.list, %arg1: !torch.list) -> !torch.list {\n" " %0 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %arg1) : (!torch.list, !torch.list) -> !torch.list\n" " return %0 : !torch.list\n" @@ -7401,36 +7351,6 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %0 = call @\"__torch_mlir_shape_fn.aten.convolution\"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8) : (!torch.list, !torch.list, !torch.optional>, !torch.list, 
!torch.list, !torch.list, !torch.bool, !torch.list, !torch.int) -> !torch.list\n" " return %0 : !torch.list\n" " }\n" -" func.func @\"__torch_mlir_dtype_fn.aten._convolution.deprecated\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.optional>, %arg3: !torch.list, %arg4: !torch.list, %arg5: !torch.list, %arg6: !torch.bool, %arg7: !torch.list, %arg8: !torch.int, %arg9: !torch.bool, %arg10: !torch.bool, %arg11: !torch.bool) -> !torch.int {\n" -" %int10 = torch.constant.int 10\n" -" %int9 = torch.constant.int 9\n" -" %int5 = torch.constant.int 5\n" -" %int11 = torch.constant.int 11\n" -" %none = torch.constant.none\n" -" %str = torch.constant.str \"AssertionError: \"\n" -" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" -" %1:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" -" %2 = torch.aten.eq.int %0#1, %1#1 : !torch.int, !torch.int -> !torch.bool\n" -" torch.prim.If %2 -> () {\n" -" torch.prim.If.yield\n" -" } else {\n" -" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" -" torch.prim.If.yield\n" -" }\n" -" %3 = torch.prim.ListConstruct %int11, %int5, %int9, %int10 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list\n" -" %4 = torch.aten.__contains__.int_list %3, %0#1 : !torch.list, !torch.int -> !torch.bool\n" -" %5 = torch.aten.__not__ %4 : !torch.bool -> !torch.bool\n" -" torch.prim.If %5 -> () {\n" -" torch.prim.If.yield\n" -" } else {\n" -" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" -" torch.prim.If.yield\n" -" }\n" -" %6 = torch.prim.ListConstruct %0#0, %1#0 : (!torch.int, !torch.int) -> !torch.list>\n" -" %7 = torch.prim.ListConstruct %0#1, %1#1 : (!torch.int, !torch.int) -> !torch.list\n" -" %8 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%6, %7) : (!torch.list>, !torch.list) -> !torch.int\n" -" return %8 : !torch.int\n" -" }\n" " func.func @\"__torch_mlir_shape_fn.aten.flip\"(%arg0: !torch.list, %arg1: !torch.list) -> !torch.list {\n" " return %arg0 : !torch.list\n" " }\n" @@ -7784,85 +7704,6 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " func.func @\"__torch_mlir_shape_fn.aten.fft_fft\"(%arg0: !torch.list, %arg1: !torch.optional, %arg2: !torch.int, %arg3: !torch.optional) -> !torch.list {\n" " return %arg0 : !torch.list\n" " }\n" -" func.func @\"__torch_mlir_dtype_fn.aten.fft_fft\"(%arg0: !torch.tuple, %arg1: !torch.optional, %arg2: !torch.int, %arg3: !torch.optional) -> !torch.int {\n" -" %none = torch.constant.none\n" -" %str = torch.constant.str \"AssertionError: Unsupported dtype\"\n" -" %int4 = torch.constant.int 4\n" -" %int3 = torch.constant.int 3\n" -" %int2 = torch.constant.int 2\n" -" %int1 = torch.constant.int 1\n" -" %int0 = torch.constant.int 0\n" -" %int11 = torch.constant.int 11\n" -" %int7 = torch.constant.int 7\n" -" %int6 = torch.constant.int 6\n" -" %int10 = torch.constant.int 10\n" -" %true = torch.constant.bool true\n" -" %int9 = torch.constant.int 9\n" -" %0 = torch.prim.Uninitialized : !torch.int\n" -" %1:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" -" %2 = torch.aten.eq.int %1#1, %int9 : !torch.int, !torch.int -> !torch.bool\n" -" %3 = torch.prim.If %2 -> (!torch.bool) {\n" -" torch.prim.If.yield %true : !torch.bool\n" -" } else {\n" -" %5 = torch.aten.eq.int %1#1, %int10 : !torch.int, !torch.int -> !torch.bool\n" -" torch.prim.If.yield %5 : !torch.bool\n" -" }\n" -" %4 = torch.prim.If %3 -> (!torch.int) {\n" -" 
torch.prim.If.yield %1#1 : !torch.int\n" -" } else {\n" -" %5 = torch.aten.eq.int %1#1, %int6 : !torch.int, !torch.int -> !torch.bool\n" -" %6 = torch.prim.If %5 -> (!torch.int) {\n" -" torch.prim.If.yield %int9 : !torch.int\n" -" } else {\n" -" %7 = torch.aten.eq.int %1#1, %int7 : !torch.int, !torch.int -> !torch.bool\n" -" %8 = torch.prim.If %7 -> (!torch.int) {\n" -" torch.prim.If.yield %int10 : !torch.int\n" -" } else {\n" -" %9 = torch.aten.eq.int %1#1, %int11 : !torch.int, !torch.int -> !torch.bool\n" -" %10 = torch.prim.If %9 -> (!torch.bool) {\n" -" torch.prim.If.yield %true : !torch.bool\n" -" } else {\n" -" %16 = torch.aten.eq.int %1#1, %int0 : !torch.int, !torch.int -> !torch.bool\n" -" torch.prim.If.yield %16 : !torch.bool\n" -" }\n" -" %11 = torch.prim.If %10 -> (!torch.bool) {\n" -" torch.prim.If.yield %true : !torch.bool\n" -" } else {\n" -" %16 = torch.aten.eq.int %1#1, %int1 : !torch.int, !torch.int -> !torch.bool\n" -" torch.prim.If.yield %16 : !torch.bool\n" -" }\n" -" %12 = torch.prim.If %11 -> (!torch.bool) {\n" -" torch.prim.If.yield %true : !torch.bool\n" -" } else {\n" -" %16 = torch.aten.eq.int %1#1, %int2 : !torch.int, !torch.int -> !torch.bool\n" -" torch.prim.If.yield %16 : !torch.bool\n" -" }\n" -" %13 = torch.prim.If %12 -> (!torch.bool) {\n" -" torch.prim.If.yield %true : !torch.bool\n" -" } else {\n" -" %16 = torch.aten.eq.int %1#1, %int3 : !torch.int, !torch.int -> !torch.bool\n" -" torch.prim.If.yield %16 : !torch.bool\n" -" }\n" -" %14 = torch.prim.If %13 -> (!torch.bool) {\n" -" torch.prim.If.yield %true : !torch.bool\n" -" } else {\n" -" %16 = torch.aten.eq.int %1#1, %int4 : !torch.int, !torch.int -> !torch.bool\n" -" torch.prim.If.yield %16 : !torch.bool\n" -" }\n" -" %15 = torch.prim.If %14 -> (!torch.int) {\n" -" torch.prim.If.yield %int9 : !torch.int\n" -" } else {\n" -" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" -" torch.prim.If.yield %0 : !torch.int\n" -" }\n" -" torch.prim.If.yield %15 : !torch.int\n" -" }\n" -" torch.prim.If.yield %8 : !torch.int\n" -" }\n" -" torch.prim.If.yield %6 : !torch.int\n" -" }\n" -" return %4 : !torch.int\n" -" }\n" " func.func @\"__torch_mlir_shape_fn.aten.bincount\"(%arg0: !torch.list, %arg1: !torch.optional>, %arg2: !torch.int) -> !torch.list {\n" " %0 = call @__torch__.hacky_get_unknown_dimension_size() : () -> !torch.int\n" " %1 = torch.prim.ListConstruct %0 : (!torch.int) -> !torch.list\n" @@ -7907,6 +7748,841 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %4 = torch.prim.ListConstruct %0, %1, %2, %3 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list\n" " return %4 : !torch.list\n" " }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.tanh\"(%arg0: !torch.tuple) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = call @__torch__._get_dtype_of_floating_point_op(%0#1) : (!torch.int) -> !torch.int\n" +" return %1 : !torch.int\n" +" }\n" +" func.func @__torch__._get_dtype_of_floating_point_op(%arg0: !torch.int) -> !torch.int {\n" +" %true = torch.constant.bool true\n" +" %false = torch.constant.bool false\n" +" %int6 = torch.constant.int 6\n" +" %0 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_float_dtype(%arg0) : (!torch.int) -> !torch.bool\n" +" %1 = torch.prim.If %0 -> (!torch.bool) {\n" +" %4 = torch.aten.ne.int %arg0, %int6 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %4 : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %false : 
!torch.bool\n" +" }\n" +" %2 = torch.prim.If %1 -> (!torch.bool) {\n" +" torch.prim.If.yield %true : !torch.bool\n" +" } else {\n" +" %4 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%arg0) : (!torch.int) -> !torch.bool\n" +" torch.prim.If.yield %4 : !torch.bool\n" +" }\n" +" %3 = torch.prim.If %2 -> (!torch.int) {\n" +" torch.prim.If.yield %arg0 : !torch.int\n" +" } else {\n" +" torch.prim.If.yield %int6 : !torch.int\n" +" }\n" +" return %3 : !torch.int\n" +" }\n" +" func.func @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_float_dtype(%arg0: !torch.int) -> !torch.bool {\n" +" %0 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.all_float_dtypes() : () -> !torch.list\n" +" %1 = torch.aten.__contains__.int_list %0, %arg0 : !torch.list, !torch.int -> !torch.bool\n" +" return %1 : !torch.bool\n" +" }\n" +" func.func @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.all_float_dtypes() -> !torch.list {\n" +" %int7 = torch.constant.int 7\n" +" %int6 = torch.constant.int 6\n" +" %int15 = torch.constant.int 15\n" +" %int5 = torch.constant.int 5\n" +" %0 = torch.prim.ListConstruct %int5, %int15, %int6, %int7 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list\n" +" return %0 : !torch.list\n" +" }\n" +" func.func @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%arg0: !torch.int) -> !torch.bool {\n" +" %0 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.all_complex_dtypes() : () -> !torch.list\n" +" %1 = torch.aten.__contains__.int_list %0, %arg0 : !torch.list, !torch.int -> !torch.bool\n" +" return %1 : !torch.bool\n" +" }\n" +" func.func @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.all_complex_dtypes() -> !torch.list {\n" +" %int10 = torch.constant.int 10\n" +" %int9 = torch.constant.int 9\n" +" %0 = torch.prim.ListConstruct %int9, %int10 : (!torch.int, !torch.int) -> !torch.list\n" +" return %0 : !torch.list\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.exp\"(%arg0: !torch.tuple) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = call @__torch__._get_dtype_of_floating_point_op(%0#1) : (!torch.int) -> !torch.int\n" +" return %1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.expm1\"(%arg0: !torch.tuple) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = call @__torch__._get_dtype_of_floating_point_op(%0#1) : (!torch.int) -> !torch.int\n" +" return %1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.sin\"(%arg0: !torch.tuple) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = call @__torch__._get_dtype_of_floating_point_op(%0#1) : (!torch.int) -> !torch.int\n" +" return %1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.cos\"(%arg0: !torch.tuple) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = call @__torch__._get_dtype_of_floating_point_op(%0#1) : (!torch.int) -> !torch.int\n" +" return %1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.sigmoid\"(%arg0: !torch.tuple) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, 
!torch.int\n" +" %1 = call @__torch__._get_dtype_of_floating_point_op(%0#1) : (!torch.int) -> !torch.int\n" +" return %1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.reciprocal\"(%arg0: !torch.tuple) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = call @__torch__._get_dtype_of_floating_point_op(%0#1) : (!torch.int) -> !torch.int\n" +" return %1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.sqrt\"(%arg0: !torch.tuple) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = call @__torch__._get_dtype_of_floating_point_op(%0#1) : (!torch.int) -> !torch.int\n" +" return %1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.log\"(%arg0: !torch.tuple) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = call @__torch__._get_dtype_of_floating_point_op(%0#1) : (!torch.int) -> !torch.int\n" +" return %1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.log2\"(%arg0: !torch.tuple) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = call @__torch__._get_dtype_of_floating_point_op(%0#1) : (!torch.int) -> !torch.int\n" +" return %1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.log1p\"(%arg0: !torch.tuple) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = call @__torch__._get_dtype_of_floating_point_op(%0#1) : (!torch.int) -> !torch.int\n" +" return %1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.rsqrt\"(%arg0: !torch.tuple) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = call @__torch__._get_dtype_of_floating_point_op(%0#1) : (!torch.int) -> !torch.int\n" +" return %1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.erf\"(%arg0: !torch.tuple) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = call @__torch__._get_dtype_of_floating_point_op(%0#1) : (!torch.int) -> !torch.int\n" +" return %1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.softplus\"(%arg0: !torch.tuple, %arg1: !torch.union, %arg2: !torch.union) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_integer_dtype(%0#1) : (!torch.int) -> !torch.bool\n" +" %2 = torch.prim.If %1 -> (!torch.int) {\n" +" torch.prim.If.yield %0#1 : !torch.int\n" +" } else {\n" +" %3 = func.call @__torch__._get_dtype_of_floating_point_op(%0#1) : (!torch.int) -> !torch.int\n" +" torch.prim.If.yield %3 : !torch.int\n" +" }\n" +" return %2 : !torch.int\n" +" }\n" +" func.func @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_integer_dtype(%arg0: !torch.int) -> !torch.bool {\n" +" %0 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.all_integer_dtypes() : () -> !torch.list\n" +" %1 = torch.aten.__contains__.int_list %0, %arg0 : !torch.list, !torch.int -> !torch.bool\n" +" return %1 : !torch.bool\n" +" }\n" +" func.func @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.all_integer_dtypes() -> !torch.list {\n" +" %int4 = torch.constant.int 4\n" +" %int3 = 
torch.constant.int 3\n" +" %int2 = torch.constant.int 2\n" +" %int1 = torch.constant.int 1\n" +" %int0 = torch.constant.int 0\n" +" %int11 = torch.constant.int 11\n" +" %0 = torch.prim.ListConstruct %int11, %int0, %int1, %int2, %int3, %int4 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list\n" +" return %0 : !torch.list\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.frobenius_norm.dim\"(%arg0: !torch.tuple, %arg1: !torch.list, %arg2: !torch.bool) -> !torch.int {\n" +" %int6 = torch.constant.int 6\n" +" %int9 = torch.constant.int 9\n" +" %int7 = torch.constant.int 7\n" +" %int10 = torch.constant.int 10\n" +" %none = torch.constant.none\n" +" %str = torch.constant.str \"AssertionError: \"\n" +" %false = torch.constant.bool false\n" +" %true = torch.constant.bool true\n" +" %0 = torch.prim.Uninitialized : !torch.int\n" +" %1:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_integer_dtype(%1#1) : (!torch.int) -> !torch.bool\n" +" %3 = torch.aten.__not__ %2 : !torch.bool -> !torch.bool\n" +" torch.prim.If %3 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %4 = torch.aten.eq.int %1#1, %int10 : !torch.int, !torch.int -> !torch.bool\n" +" %5:2 = torch.prim.If %4 -> (!torch.bool, !torch.int) {\n" +" torch.prim.If.yield %true, %int7 : !torch.bool, !torch.int\n" +" } else {\n" +" %7 = torch.aten.eq.int %1#1, %int9 : !torch.int, !torch.int -> !torch.bool\n" +" %8:2 = torch.prim.If %7 -> (!torch.bool, !torch.int) {\n" +" torch.prim.If.yield %true, %int6 : !torch.bool, !torch.int\n" +" } else {\n" +" torch.prim.If.yield %false, %0 : !torch.bool, !torch.int\n" +" }\n" +" torch.prim.If.yield %8#0, %8#1 : !torch.bool, !torch.int\n" +" }\n" +" %6 = torch.prim.If %5#0 -> (!torch.int) {\n" +" torch.prim.If.yield %5#1 : !torch.int\n" +" } else {\n" +" %7 = func.call @__torch__._get_dtype_of_floating_point_op(%1#1) : (!torch.int) -> !torch.int\n" +" torch.prim.If.yield %7 : !torch.int\n" +" }\n" +" return %6 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.prims.sqrt\"(%arg0: !torch.tuple) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_integer_dtype(%0#1) : (!torch.int) -> !torch.bool\n" +" %2 = torch.prim.If %1 -> (!torch.int) {\n" +" torch.prim.If.yield %0#1 : !torch.int\n" +" } else {\n" +" %3 = func.call @__torch__._get_dtype_of_floating_point_op(%0#1) : (!torch.int) -> !torch.int\n" +" torch.prim.If.yield %3 : !torch.int\n" +" }\n" +" return %2 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.abs\"(%arg0: !torch.tuple) -> !torch.int {\n" +" %int6 = torch.constant.int 6\n" +" %int9 = torch.constant.int 9\n" +" %int7 = torch.constant.int 7\n" +" %int10 = torch.constant.int 10\n" +" %false = torch.constant.bool false\n" +" %true = torch.constant.bool true\n" +" %0 = torch.prim.Uninitialized : !torch.int\n" +" %1:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %2 = torch.aten.eq.int %1#1, %int10 : !torch.int, !torch.int -> !torch.bool\n" +" %3:2 = torch.prim.If %2 -> (!torch.bool, !torch.int) {\n" +" torch.prim.If.yield %true, %int7 : !torch.bool, !torch.int\n" +" } else {\n" +" %5 = torch.aten.eq.int %1#1, %int9 : 
!torch.int, !torch.int -> !torch.bool\n" +" %6:2 = torch.prim.If %5 -> (!torch.bool, !torch.int) {\n" +" torch.prim.If.yield %true, %int6 : !torch.bool, !torch.int\n" +" } else {\n" +" torch.prim.If.yield %false, %0 : !torch.bool, !torch.int\n" +" }\n" +" torch.prim.If.yield %6#0, %6#1 : !torch.bool, !torch.int\n" +" }\n" +" %4 = torch.prim.If %3#0 -> (!torch.int) {\n" +" torch.prim.If.yield %3#1 : !torch.int\n" +" } else {\n" +" torch.prim.If.yield %1#1 : !torch.int\n" +" }\n" +" return %4 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.adaptive_avg_pool2d\"(%arg0: !torch.tuple, %arg1: !torch.list) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.avg_pool2d\"(%arg0: !torch.tuple, %arg1: !torch.list, %arg2: !torch.list, %arg3: !torch.list, %arg4: !torch.bool, %arg5: !torch.bool, %arg6: !torch.optional) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.batch_norm\"(%arg0: !torch.tuple, %arg1: !torch.optional>, %arg2: !torch.optional>, %arg3: !torch.optional>, %arg4: !torch.optional>, %arg5: !torch.bool, %arg6: !torch.float, %arg7: !torch.float, %arg8: !torch.bool) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.bernoulli_.float\"(%arg0: !torch.tuple, %arg1: !torch.float, %arg2: !torch.any) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.bernoulli\"(%arg0: !torch.tuple, %arg1: !torch.any) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.bernoulli.Tensor\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.any) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.bitwise_not\"(%arg0: !torch.tuple) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.broadcast_to\"(%arg0: !torch.tuple, %arg1: !torch.list) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.ceil\"(%arg0: !torch.tuple) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.clamp_max\"(%arg0: !torch.tuple, %arg1: !torch.union) -> !torch.int {\n" +" %int4 = torch.constant.int 4\n" +" %int11 = torch.constant.int 11\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = torch.aten.eq.int %0#1, %int11 : !torch.int, !torch.int -> !torch.bool\n" +" %2 = torch.prim.If %1 -> (!torch.int) {\n" +" torch.prim.If.yield %int4 : !torch.int\n" +" } else {\n" +" torch.prim.If.yield %0#1 : !torch.int\n" +" }\n" +" return %2 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.clamp_min\"(%arg0: 
!torch.tuple, %arg1: !torch.union) -> !torch.int {\n" +" %int4 = torch.constant.int 4\n" +" %int11 = torch.constant.int 11\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = torch.aten.eq.int %0#1, %int11 : !torch.int, !torch.int -> !torch.bool\n" +" %2 = torch.prim.If %1 -> (!torch.int) {\n" +" torch.prim.If.yield %int4 : !torch.int\n" +" } else {\n" +" torch.prim.If.yield %0#1 : !torch.int\n" +" }\n" +" return %2 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.clamp\"(%arg0: !torch.tuple, %arg1: !torch.union, %arg2: !torch.union) -> !torch.int {\n" +" %int4 = torch.constant.int 4\n" +" %int11 = torch.constant.int 11\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = torch.aten.eq.int %0#1, %int11 : !torch.int, !torch.int -> !torch.bool\n" +" %2 = torch.prim.If %1 -> (!torch.int) {\n" +" torch.prim.If.yield %int4 : !torch.int\n" +" } else {\n" +" torch.prim.If.yield %0#1 : !torch.int\n" +" }\n" +" return %2 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.clone\"(%arg0: !torch.tuple, %arg1: !torch.optional) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.constant_pad_nd\"(%arg0: !torch.tuple, %arg1: !torch.list, %arg2: !torch.union) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.contiguous\"(%arg0: !torch.tuple, %arg1: !torch.int) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.copy\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.bool) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.cpu\"(%arg0: !torch.tuple) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.cumsum\"(%arg0: !torch.tuple, %arg1: !torch.int, %arg2: !torch.optional) -> !torch.int {\n" +" %int4 = torch.constant.int 4\n" +" %none = torch.constant.none\n" +" %0 = torch.aten.__isnot__ %arg2, %none : !torch.optional, !torch.none -> !torch.bool\n" +" %1 = torch.prim.If %0 -> (!torch.int) {\n" +" %2 = torch.prim.unchecked_cast %arg2 : !torch.optional -> !torch.int\n" +" torch.prim.If.yield %2 : !torch.int\n" +" } else {\n" +" %2:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %3 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_integer_dtype(%2#1) : (!torch.int) -> !torch.bool\n" +" %4 = torch.prim.If %3 -> (!torch.int) {\n" +" torch.prim.If.yield %int4 : !torch.int\n" +" } else {\n" +" torch.prim.If.yield %2#1 : !torch.int\n" +" }\n" +" torch.prim.If.yield %4 : !torch.int\n" +" }\n" +" return %1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.detach\"(%arg0: !torch.tuple) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.dropout\"(%arg0: !torch.tuple, %arg1: !torch.float, %arg2: !torch.bool) -> !torch.int 
{\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.expand_as\"(%arg0: !torch.tuple, %arg1: !torch.tuple) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.expand\"(%arg0: !torch.tuple, %arg1: !torch.list, %arg2: !torch.bool) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.fill.Scalar\"(%arg0: !torch.tuple, %arg1: !torch.union) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.fill.Tensor\"(%arg0: !torch.tuple, %arg1: !torch.tuple) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.flatten.using_ints\"(%arg0: !torch.tuple, %arg1: !torch.int, %arg2: !torch.int) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.flip\"(%arg0: !torch.tuple, %arg1: !torch.list) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.floor\"(%arg0: !torch.tuple) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.gather\"(%arg0: !torch.tuple, %arg1: !torch.int, %arg2: !torch.tuple, %arg3: !torch.bool) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.gelu_backward\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.str) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" +" %2 = torch.prim.ListConstruct %0#0, %1#0 : (!torch.int, !torch.int) -> !torch.list>\n" +" %3 = torch.prim.ListConstruct %0#1, %1#1 : (!torch.int, !torch.int) -> !torch.list\n" +" %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list>, !torch.list) -> !torch.int\n" +" return %4 : !torch.int\n" +" }\n" +" func.func @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%arg0: !torch.list>, %arg1: !torch.list) -> !torch.int {\n" +" %0 = torch.promote_dtypes %arg0, %arg1 : (!torch.list>, !torch.list) -> !torch.int\n" +" return %0 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.gelu\"(%arg0: !torch.tuple, %arg1: !torch.str) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.hardsigmoid\"(%arg0: !torch.tuple) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func 
@\"__torch_mlir_dtype_fn.aten.hardswish\"(%arg0: !torch.tuple) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.hardtanh_backward\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.union, %arg3: !torch.union) -> !torch.int {\n" +" %int6 = torch.constant.int 6\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_integer_dtype(%0#1) : (!torch.int) -> !torch.bool\n" +" %2 = torch.prim.If %1 -> (!torch.int) {\n" +" torch.prim.If.yield %int6 : !torch.int\n" +" } else {\n" +" torch.prim.If.yield %0#1 : !torch.int\n" +" }\n" +" return %2 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.hardtanh\"(%arg0: !torch.tuple, %arg1: !torch.union, %arg2: !torch.union) -> !torch.int {\n" +" %none = torch.constant.none\n" +" %str = torch.constant.str \"AssertionError: \"\n" +" %int11 = torch.constant.int 11\n" +" %int0 = torch.constant.int 0\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = torch.prim.ListConstruct %int0, %int11 : (!torch.int, !torch.int) -> !torch.list\n" +" %2 = torch.aten.__contains__.int_list %1, %0#1 : !torch.list, !torch.int -> !torch.bool\n" +" %3 = torch.aten.__not__ %2 : !torch.bool -> !torch.bool\n" +" torch.prim.If %3 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.index_put.hacked_twin\"(%arg0: !torch.tuple, %arg1: !torch.list>, %arg2: !torch.tuple, %arg3: !torch.bool) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten._index_put_impl\"(%arg0: !torch.tuple, %arg1: !torch.list>>, %arg2: !torch.tuple, %arg3: !torch.bool, %arg4: !torch.bool) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.index_put\"(%arg0: !torch.tuple, %arg1: !torch.list>>, %arg2: !torch.tuple, %arg3: !torch.bool) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.index_select\"(%arg0: !torch.tuple, %arg1: !torch.int, %arg2: !torch.tuple) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.index.Tensor_hacked_twin\"(%arg0: !torch.tuple, %arg1: !torch.list>) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.index.Tensor\"(%arg0: !torch.tuple, %arg1: !torch.list>>) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.layer_norm\"(%arg0: !torch.tuple, %arg1: !torch.list, %arg2: !torch.optional>, %arg3: !torch.optional>, %arg4: !torch.float, %arg5: !torch.bool) -> !torch.int {\n" +" %none = torch.constant.none\n" +" 
%str = torch.constant.str \"AssertionError: \"\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_integer_dtype(%0#1) : (!torch.int) -> !torch.bool\n" +" %2 = torch.aten.__not__ %1 : !torch.bool -> !torch.bool\n" +" torch.prim.If %2 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.leaky_relu_backward\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.union, %arg3: !torch.bool) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" +" %2 = torch.prim.ListConstruct %0#0, %1#0 : (!torch.int, !torch.int) -> !torch.list>\n" +" %3 = torch.prim.ListConstruct %0#1, %1#1 : (!torch.int, !torch.int) -> !torch.list\n" +" %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list>, !torch.list) -> !torch.int\n" +" return %4 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.lift_fresh_copy\"(%arg0: !torch.tuple) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten._log_softmax_backward_data\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.int, %arg3: !torch.int) -> !torch.int {\n" +" return %arg3 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.masked_fill.Scalar\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.union) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.masked_fill_.Scalar\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.union) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.masked_fill.Tensor\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.tuple) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.masked_select\"(%arg0: !torch.tuple, %arg1: !torch.tuple) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.max_pool2d\"(%arg0: !torch.tuple, %arg1: !torch.list, %arg2: !torch.list, %arg3: !torch.list, %arg4: !torch.list, %arg5: !torch.bool) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.max_pool2d_with_indices\"(%arg0: !torch.tuple, %arg1: !torch.list, %arg2: !torch.list, %arg3: !torch.list, %arg4: !torch.list, %arg5: !torch.bool) -> !torch.tuple {\n" +" %int4 = torch.constant.int 4\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = torch.prim.TupleConstruct %0#1, %int4 : !torch.int, !torch.int -> !torch.tuple\n" +" return %1 : !torch.tuple\n" +" }\n" +" func.func 
@\"__torch_mlir_dtype_fn.aten.mish\"(%arg0: !torch.tuple) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.narrow\"(%arg0: !torch.tuple, %arg1: !torch.int, %arg2: !torch.int, %arg3: !torch.int) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.neg\"(%arg0: !torch.tuple) -> !torch.int {\n" +" %none = torch.constant.none\n" +" %str = torch.constant.str \"AssertionError: \"\n" +" %int11 = torch.constant.int 11\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = torch.aten.ne.int %0#1, %int11 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %1 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.numpy_T\"(%arg0: !torch.tuple) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.pad\"(%arg0: !torch.tuple, %arg1: !torch.list, %arg2: !torch.str, %arg3: !torch.optional) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.permute\"(%arg0: !torch.tuple, %arg1: !torch.list) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.pow.Tensor_Tensor\"(%arg0: !torch.tuple, %arg1: !torch.tuple) -> !torch.int {\n" +" %int4 = torch.constant.int 4\n" +" %int11 = torch.constant.int 11\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" +" %2 = torch.prim.ListConstruct %0#0, %1#0 : (!torch.int, !torch.int) -> !torch.list>\n" +" %3 = torch.prim.ListConstruct %0#1, %1#1 : (!torch.int, !torch.int) -> !torch.list\n" +" %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list>, !torch.list) -> !torch.int\n" +" %5 = torch.aten.eq.int %4, %int11 : !torch.int, !torch.int -> !torch.bool\n" +" %6 = torch.prim.If %5 -> (!torch.int) {\n" +" torch.prim.If.yield %int4 : !torch.int\n" +" } else {\n" +" torch.prim.If.yield %4 : !torch.int\n" +" }\n" +" return %6 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.prelu\"(%arg0: !torch.tuple, %arg1: !torch.tuple) -> !torch.int {\n" +" %none = torch.constant.none\n" +" %str = torch.constant.str \"AssertionError: \"\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" +" %2 = torch.aten.eq.int %0#1, %1#1 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %2 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.relu6\"(%arg0: !torch.tuple) -> !torch.int {\n" +" %none = torch.constant.none\n" +" 
%str = torch.constant.str \"AssertionError: \"\n" +" %int11 = torch.constant.int 11\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = torch.aten.ne.int %0#1, %int11 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %1 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.relu\"(%arg0: !torch.tuple) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.repeat\"(%arg0: !torch.tuple, %arg1: !torch.list) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten._reshape_alias\"(%arg0: !torch.tuple, %arg1: !torch.list, %arg2: !torch.list) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.reshape\"(%arg0: !torch.tuple, %arg1: !torch.list) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.resize_\"(%arg0: !torch.tuple, %arg1: !torch.list, %arg2: !torch.optional) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.roll\"(%arg0: !torch.tuple, %arg1: !torch.list, %arg2: !torch.list) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.round\"(%arg0: !torch.tuple) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.scatter_reduce.two\"(%arg0: !torch.tuple, %arg1: !torch.int, %arg2: !torch.tuple, %arg3: !torch.tuple, %arg4: !torch.str, %arg5: !torch.bool) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.select.int\"(%arg0: !torch.tuple, %arg1: !torch.int, %arg2: !torch.int) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.select_scatter\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.int, %arg3: !torch.int) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.silu\"(%arg0: !torch.tuple) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.slice_scatter\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.int, %arg3: !torch.optional, %arg4: !torch.optional, %arg5: !torch.int) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func 
@\"__torch_mlir_dtype_fn.aten.slice.Tensor\"(%arg0: !torch.tuple, %arg1: !torch.int, %arg2: !torch.optional, %arg3: !torch.optional, %arg4: !torch.int) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten._softmax_backward_data\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.int, %arg3: !torch.int) -> !torch.int {\n" +" return %arg3 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.square\"(%arg0: !torch.tuple) -> !torch.int {\n" +" %int4 = torch.constant.int 4\n" +" %int11 = torch.constant.int 11\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = torch.aten.eq.int %0#1, %int11 : !torch.int, !torch.int -> !torch.bool\n" +" %2 = torch.prim.If %1 -> (!torch.int) {\n" +" torch.prim.If.yield %int4 : !torch.int\n" +" } else {\n" +" torch.prim.If.yield %0#1 : !torch.int\n" +" }\n" +" return %2 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.squeeze.dim\"(%arg0: !torch.tuple, %arg1: !torch.int) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.squeeze\"(%arg0: !torch.tuple) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.tanh_backward\"(%arg0: !torch.tuple, %arg1: !torch.tuple) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" +" %2 = torch.prim.ListConstruct %0#0, %1#0 : (!torch.int, !torch.int) -> !torch.list>\n" +" %3 = torch.prim.ListConstruct %0#1, %1#1 : (!torch.int, !torch.int) -> !torch.list\n" +" %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list>, !torch.list) -> !torch.int\n" +" return %4 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.threshold\"(%arg0: !torch.tuple, %arg1: !torch.union, %arg2: !torch.union) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.t\"(%arg0: !torch.tuple) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.to.prim_Device\"(%arg0: !torch.tuple, %arg1: !torch.optional, %arg2: !torch.optional, %arg3: !torch.bool, %arg4: !torch.bool) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.transpose.int\"(%arg0: !torch.tuple, %arg1: !torch.int, %arg2: !torch.int) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.triu\"(%arg0: !torch.tuple, %arg1: !torch.int) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.uniform\"(%arg0: !torch.tuple, %arg1: !torch.float, %arg2: !torch.float, %arg3: !torch.any) -> 
!torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten._unsafe_view\"(%arg0: !torch.tuple, %arg1: !torch.list) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.unsqueeze\"(%arg0: !torch.tuple, %arg1: !torch.int) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.upsample_nearest2d_backward\"(%arg0: !torch.tuple, %arg1: !torch.list, %arg2: !torch.list, %arg3: !torch.optional, %arg4: !torch.optional) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.upsample_nearest2d\"(%arg0: !torch.tuple, %arg1: !torch.list, %arg2: !torch.optional, %arg3: !torch.optional) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.view\"(%arg0: !torch.tuple, %arg1: !torch.list) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.zero\"(%arg0: !torch.tuple) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.zero_\"(%arg0: !torch.tuple) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.prim.abs.Scalar\"(%arg0: !torch.union) -> !torch.int {\n" +" %0 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg0) : (!torch.union) -> !torch.int\n" +" return %0 : !torch.int\n" +" }\n" +" func.func @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg0: !torch.union) -> !torch.int {\n" +" %0 = torch.prim.NumToTensor.Scalar %arg0 : !torch.union -> !torch.tensor\n" +" %1 = torch.prim.dtype %0 : !torch.tensor -> !torch.int\n" +" return %1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.nll_loss_backward\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.tuple, %arg3: !torch.optional>, %arg4: !torch.int, %arg5: !torch.int, %arg6: !torch.tuple) -> !torch.int {\n" +" %int4 = torch.constant.int 4\n" +" %int11 = torch.constant.int 11\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" +" %2 = torch.prim.ListConstruct %1#0, %0#0 : (!torch.int, !torch.int) -> !torch.list>\n" +" %3 = torch.prim.ListConstruct %1#1, %0#1 : (!torch.int, !torch.int) -> !torch.list\n" +" %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list>, !torch.list) -> !torch.int\n" +" %5 = torch.aten.eq.int %4, %int11 : !torch.int, !torch.int -> !torch.bool\n" +" %6 = torch.prim.If %5 -> (!torch.int) {\n" +" torch.prim.If.yield %int4 : !torch.int\n" +" } else {\n" +" torch.prim.If.yield %4 : !torch.int\n" +" 
}\n" +" return %6 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.max_pool2d_with_indices_backward\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.list, %arg3: !torch.list, %arg4: !torch.list, %arg5: !torch.list, %arg6: !torch.bool, %arg7: !torch.tuple) -> !torch.int {\n" +" %none = torch.constant.none\n" +" %str = torch.constant.str \"AssertionError: \"\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" +" %2 = torch.aten.eq.int %0#1, %1#1 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %2 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" return %1#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.all\"(%arg0: !torch.tuple) -> !torch.int {\n" +" %int11 = torch.constant.int 11\n" +" %int0 = torch.constant.int 0\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = torch.aten.eq.int %0#1, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" %2 = torch.prim.If %1 -> (!torch.int) {\n" +" torch.prim.If.yield %int0 : !torch.int\n" +" } else {\n" +" torch.prim.If.yield %int11 : !torch.int\n" +" }\n" +" return %2 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.any\"(%arg0: !torch.tuple) -> !torch.int {\n" +" %int11 = torch.constant.int 11\n" +" %int0 = torch.constant.int 0\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = torch.aten.eq.int %0#1, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" %2 = torch.prim.If %1 -> (!torch.int) {\n" +" torch.prim.If.yield %int0 : !torch.int\n" +" } else {\n" +" torch.prim.If.yield %int11 : !torch.int\n" +" }\n" +" return %2 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.eq.Scalar\"(%arg0: !torch.tuple, %arg1: !torch.union) -> !torch.int {\n" +" %int11 = torch.constant.int 11\n" +" return %int11 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.eq.Tensor\"(%arg0: !torch.tuple, %arg1: !torch.tuple) -> !torch.int {\n" +" %int11 = torch.constant.int 11\n" +" return %int11 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.ge.Scalar\"(%arg0: !torch.tuple, %arg1: !torch.union) -> !torch.int {\n" +" %int11 = torch.constant.int 11\n" +" return %int11 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.gt.Scalar\"(%arg0: !torch.tuple, %arg1: !torch.union) -> !torch.int {\n" +" %int11 = torch.constant.int 11\n" +" return %int11 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.gt.Tensor\"(%arg0: !torch.tuple, %arg1: !torch.tuple) -> !torch.int {\n" +" %int11 = torch.constant.int 11\n" +" return %int11 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.ge.Tensor\"(%arg0: !torch.tuple, %arg1: !torch.tuple) -> !torch.int {\n" +" %int11 = torch.constant.int 11\n" +" return %int11 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.le.Scalar\"(%arg0: !torch.tuple, %arg1: !torch.union) -> !torch.int {\n" +" %int11 = torch.constant.int 11\n" +" return %int11 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.logical_and\"(%arg0: !torch.tuple, %arg1: !torch.tuple) -> !torch.int {\n" +" %int11 = torch.constant.int 11\n" +" return %int11 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.logical_not\"(%arg0: !torch.tuple) -> !torch.int {\n" +" %int11 
= torch.constant.int 11\n" +" return %int11 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.logical_or\"(%arg0: !torch.tuple, %arg1: !torch.tuple) -> !torch.int {\n" +" %int11 = torch.constant.int 11\n" +" return %int11 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.logical_xor\"(%arg0: !torch.tuple, %arg1: !torch.tuple) -> !torch.int {\n" +" %int11 = torch.constant.int 11\n" +" return %int11 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.lt.Scalar\"(%arg0: !torch.tuple, %arg1: !torch.union) -> !torch.int {\n" +" %int11 = torch.constant.int 11\n" +" return %int11 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.lt.Tensor\"(%arg0: !torch.tuple, %arg1: !torch.tuple) -> !torch.int {\n" +" %int11 = torch.constant.int 11\n" +" return %int11 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.le.Tensor\"(%arg0: !torch.tuple, %arg1: !torch.tuple) -> !torch.int {\n" +" %int11 = torch.constant.int 11\n" +" return %int11 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.ne.Scalar\"(%arg0: !torch.tuple, %arg1: !torch.union) -> !torch.int {\n" +" %int11 = torch.constant.int 11\n" +" return %int11 : !torch.int\n" +" }\n" " func.func @\"__torch_mlir_dtype_fn.aten.add\"(%arg0: !torch.union, %arg1: !torch.union) -> !torch.int {\n" " %none = torch.constant.none\n" " %0 = torch.prim.ListConstruct %none, %none : (!torch.none, !torch.none) -> !torch.list>\n" @@ -7916,6 +8592,1785 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%0, %3) : (!torch.list>, !torch.list) -> !torch.int\n" " return %4 : !torch.int\n" " }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.fft_fft\"(%arg0: !torch.tuple, %arg1: !torch.optional, %arg2: !torch.int, %arg3: !torch.optional) -> !torch.int {\n" +" %none = torch.constant.none\n" +" %str = torch.constant.str \"AssertionError: Unsupported dtype\"\n" +" %int10 = torch.constant.int 10\n" +" %int7 = torch.constant.int 7\n" +" %int9 = torch.constant.int 9\n" +" %int6 = torch.constant.int 6\n" +" %int8 = torch.constant.int 8\n" +" %int5 = torch.constant.int 5\n" +" %0 = torch.prim.Uninitialized : !torch.int\n" +" %1:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%1#1) : (!torch.int) -> !torch.bool\n" +" %3 = torch.prim.If %2 -> (!torch.int) {\n" +" torch.prim.If.yield %1#1 : !torch.int\n" +" } else {\n" +" %4 = torch.aten.eq.int %1#1, %int5 : !torch.int, !torch.int -> !torch.bool\n" +" %5 = torch.prim.If %4 -> (!torch.int) {\n" +" torch.prim.If.yield %int8 : !torch.int\n" +" } else {\n" +" %6 = torch.aten.eq.int %1#1, %int6 : !torch.int, !torch.int -> !torch.bool\n" +" %7 = torch.prim.If %6 -> (!torch.int) {\n" +" torch.prim.If.yield %int9 : !torch.int\n" +" } else {\n" +" %8 = torch.aten.eq.int %1#1, %int7 : !torch.int, !torch.int -> !torch.bool\n" +" %9 = torch.prim.If %8 -> (!torch.int) {\n" +" torch.prim.If.yield %int10 : !torch.int\n" +" } else {\n" +" %10 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_integer_dtype(%1#1) : (!torch.int) -> !torch.bool\n" +" %11 = torch.prim.If %10 -> (!torch.int) {\n" +" torch.prim.If.yield %int9 : !torch.int\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield %0 : 
!torch.int\n" +" }\n" +" torch.prim.If.yield %11 : !torch.int\n" +" }\n" +" torch.prim.If.yield %9 : !torch.int\n" +" }\n" +" torch.prim.If.yield %7 : !torch.int\n" +" }\n" +" torch.prim.If.yield %5 : !torch.int\n" +" }\n" +" return %3 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.rsub.Scalar\"(%arg0: !torch.tuple, %arg1: !torch.union, %arg2: !torch.union) -> !torch.int {\n" +" %none = torch.constant.none\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = torch.prim.ListConstruct %0#0, %none : (!torch.int, !torch.none) -> !torch.list>\n" +" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg1) : (!torch.union) -> !torch.int\n" +" %3 = torch.prim.ListConstruct %0#1, %2 : (!torch.int, !torch.int) -> !torch.list\n" +" %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%1, %3) : (!torch.list>, !torch.list) -> !torch.int\n" +" return %4 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.__and__.Tensor\"(%arg0: !torch.tuple, %arg1: !torch.tuple) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" +" %1:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %2 = torch.prim.ListConstruct %1#0, %0#0 : (!torch.int, !torch.int) -> !torch.list>\n" +" %3 = torch.prim.ListConstruct %1#1, %0#1 : (!torch.int, !torch.int) -> !torch.list\n" +" %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list>, !torch.list) -> !torch.int\n" +" return %4 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.add.Tensor\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.union) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" +" %1:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %2 = torch.prim.ListConstruct %1#0, %0#0 : (!torch.int, !torch.int) -> !torch.list>\n" +" %3 = torch.prim.ListConstruct %1#1, %0#1 : (!torch.int, !torch.int) -> !torch.list\n" +" %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list>, !torch.list) -> !torch.int\n" +" return %4 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.bitwise_and.Tensor\"(%arg0: !torch.tuple, %arg1: !torch.tuple) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" +" %1:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %2 = torch.prim.ListConstruct %1#0, %0#0 : (!torch.int, !torch.int) -> !torch.list>\n" +" %3 = torch.prim.ListConstruct %1#1, %0#1 : (!torch.int, !torch.int) -> !torch.list\n" +" %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list>, !torch.list) -> !torch.int\n" +" return %4 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.bitwise_or.Tensor\"(%arg0: !torch.tuple, %arg1: !torch.tuple) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" +" %1:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %2 = torch.prim.ListConstruct %1#0, %0#0 : (!torch.int, !torch.int) -> !torch.list>\n" +" %3 = torch.prim.ListConstruct %1#1, %0#1 : (!torch.int, !torch.int) -> !torch.list\n" +" 
%4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list>, !torch.list) -> !torch.int\n" +" return %4 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.bitwise_xor.Tensor\"(%arg0: !torch.tuple, %arg1: !torch.tuple) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" +" %1:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %2 = torch.prim.ListConstruct %1#0, %0#0 : (!torch.int, !torch.int) -> !torch.list>\n" +" %3 = torch.prim.ListConstruct %1#1, %0#1 : (!torch.int, !torch.int) -> !torch.list\n" +" %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list>, !torch.list) -> !torch.int\n" +" return %4 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.bmm\"(%arg0: !torch.tuple, %arg1: !torch.tuple) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" +" %1:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_priority_of_dtype(%0#1) : (!torch.int) -> !torch.int\n" +" %3 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_priority_of_dtype(%1#1) : (!torch.int) -> !torch.int\n" +" %4 = torch.aten.lt.int %2, %3 : !torch.int, !torch.int -> !torch.bool\n" +" %5 = torch.prim.If %4 -> (!torch.int) {\n" +" torch.prim.If.yield %0#1 : !torch.int\n" +" } else {\n" +" torch.prim.If.yield %1#1 : !torch.int\n" +" }\n" +" return %5 : !torch.int\n" +" }\n" +" func.func @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_priority_of_dtype(%arg0: !torch.int) -> !torch.int {\n" +" %none = torch.constant.none\n" +" %str = torch.constant.str \"AssertionError: Cannot determine priority of dtype\"\n" +" %int15 = torch.constant.int 15\n" +" %int0 = torch.constant.int 0\n" +" %int1 = torch.constant.int 1\n" +" %int2 = torch.constant.int 2\n" +" %int3 = torch.constant.int 3\n" +" %int4 = torch.constant.int 4\n" +" %int5 = torch.constant.int 5\n" +" %int6 = torch.constant.int 6\n" +" %int7 = torch.constant.int 7\n" +" %int8 = torch.constant.int 8\n" +" %int9 = torch.constant.int 9\n" +" %int10 = torch.constant.int 10\n" +" %int11 = torch.constant.int 11\n" +" %0 = torch.prim.Uninitialized : !torch.int\n" +" %1 = torch.aten.eq.int %arg0, %int11 : !torch.int, !torch.int -> !torch.bool\n" +" %2 = torch.prim.If %1 -> (!torch.int) {\n" +" torch.prim.If.yield %int0 : !torch.int\n" +" } else {\n" +" %3 = torch.aten.eq.int %arg0, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" %4 = torch.prim.If %3 -> (!torch.int) {\n" +" torch.prim.If.yield %int1 : !torch.int\n" +" } else {\n" +" %5 = torch.aten.eq.int %arg0, %int1 : !torch.int, !torch.int -> !torch.bool\n" +" %6 = torch.prim.If %5 -> (!torch.int) {\n" +" torch.prim.If.yield %int2 : !torch.int\n" +" } else {\n" +" %7 = torch.aten.eq.int %arg0, %int2 : !torch.int, !torch.int -> !torch.bool\n" +" %8 = torch.prim.If %7 -> (!torch.int) {\n" +" torch.prim.If.yield %int3 : !torch.int\n" +" } else {\n" +" %9 = torch.aten.eq.int %arg0, %int3 : !torch.int, !torch.int -> !torch.bool\n" +" %10 = torch.prim.If %9 -> (!torch.int) {\n" +" torch.prim.If.yield %int4 : !torch.int\n" +" } else {\n" +" %11 = torch.aten.eq.int %arg0, %int4 : !torch.int, !torch.int -> !torch.bool\n" +" %12 = 
torch.prim.If %11 -> (!torch.int) {\n" +" torch.prim.If.yield %int5 : !torch.int\n" +" } else {\n" +" %13 = torch.aten.eq.int %arg0, %int15 : !torch.int, !torch.int -> !torch.bool\n" +" %14 = torch.prim.If %13 -> (!torch.int) {\n" +" torch.prim.If.yield %int6 : !torch.int\n" +" } else {\n" +" %15 = torch.aten.eq.int %arg0, %int5 : !torch.int, !torch.int -> !torch.bool\n" +" %16 = torch.prim.If %15 -> (!torch.int) {\n" +" torch.prim.If.yield %int7 : !torch.int\n" +" } else {\n" +" %17 = torch.aten.eq.int %arg0, %int6 : !torch.int, !torch.int -> !torch.bool\n" +" %18 = torch.prim.If %17 -> (!torch.int) {\n" +" torch.prim.If.yield %int8 : !torch.int\n" +" } else {\n" +" %19 = torch.aten.eq.int %arg0, %int7 : !torch.int, !torch.int -> !torch.bool\n" +" %20 = torch.prim.If %19 -> (!torch.int) {\n" +" torch.prim.If.yield %int9 : !torch.int\n" +" } else {\n" +" %21 = torch.aten.eq.int %arg0, %int9 : !torch.int, !torch.int -> !torch.bool\n" +" %22 = torch.prim.If %21 -> (!torch.int) {\n" +" torch.prim.If.yield %int10 : !torch.int\n" +" } else {\n" +" %23 = torch.aten.eq.int %arg0, %int10 : !torch.int, !torch.int -> !torch.bool\n" +" %24 = torch.prim.If %23 -> (!torch.int) {\n" +" torch.prim.If.yield %int11 : !torch.int\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield %0 : !torch.int\n" +" }\n" +" torch.prim.If.yield %24 : !torch.int\n" +" }\n" +" torch.prim.If.yield %22 : !torch.int\n" +" }\n" +" torch.prim.If.yield %20 : !torch.int\n" +" }\n" +" torch.prim.If.yield %18 : !torch.int\n" +" }\n" +" torch.prim.If.yield %16 : !torch.int\n" +" }\n" +" torch.prim.If.yield %14 : !torch.int\n" +" }\n" +" torch.prim.If.yield %12 : !torch.int\n" +" }\n" +" torch.prim.If.yield %10 : !torch.int\n" +" }\n" +" torch.prim.If.yield %8 : !torch.int\n" +" }\n" +" torch.prim.If.yield %6 : !torch.int\n" +" }\n" +" torch.prim.If.yield %4 : !torch.int\n" +" }\n" +" return %2 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.div.Tensor\"(%arg0: !torch.tuple, %arg1: !torch.tuple) -> !torch.int {\n" +" %false = torch.constant.bool false\n" +" %int6 = torch.constant.int 6\n" +" %true = torch.constant.bool true\n" +" %0:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" +" %1:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %2 = torch.prim.ListConstruct %1#0, %0#0 : (!torch.int, !torch.int) -> !torch.list>\n" +" %3 = torch.prim.ListConstruct %1#1, %0#1 : (!torch.int, !torch.int) -> !torch.list\n" +" %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list>, !torch.list) -> !torch.int\n" +" %5 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%4) : (!torch.int) -> !torch.bool\n" +" %6 = torch.prim.If %5 -> (!torch.bool) {\n" +" torch.prim.If.yield %true : !torch.bool\n" +" } else {\n" +" %8 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_float_dtype(%4) : (!torch.int) -> !torch.bool\n" +" %9 = torch.prim.If %8 -> (!torch.bool) {\n" +" %10 = torch.aten.ne.int %4, %int6 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %10 : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %false : !torch.bool\n" +" }\n" +" torch.prim.If.yield %9 : !torch.bool\n" +" }\n" +" %7 = torch.prim.If %6 -> (!torch.int) {\n" +" torch.prim.If.yield %4 : !torch.int\n" +" } else {\n" +" torch.prim.If.yield %int6 : 
!torch.int\n" +" }\n" +" return %7 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.div.Tensor_mode\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.optional) -> !torch.int {\n" +" %false = torch.constant.bool false\n" +" %int6 = torch.constant.int 6\n" +" %true = torch.constant.bool true\n" +" %0:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" +" %1:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %2 = torch.prim.ListConstruct %1#0, %0#0 : (!torch.int, !torch.int) -> !torch.list>\n" +" %3 = torch.prim.ListConstruct %1#1, %0#1 : (!torch.int, !torch.int) -> !torch.list\n" +" %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list>, !torch.list) -> !torch.int\n" +" %5 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%4) : (!torch.int) -> !torch.bool\n" +" %6 = torch.prim.If %5 -> (!torch.bool) {\n" +" torch.prim.If.yield %true : !torch.bool\n" +" } else {\n" +" %8 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_float_dtype(%4) : (!torch.int) -> !torch.bool\n" +" %9 = torch.prim.If %8 -> (!torch.bool) {\n" +" %10 = torch.aten.ne.int %4, %int6 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %10 : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %false : !torch.bool\n" +" }\n" +" torch.prim.If.yield %9 : !torch.bool\n" +" }\n" +" %7 = torch.prim.If %6 -> (!torch.int) {\n" +" torch.prim.If.yield %4 : !torch.int\n" +" } else {\n" +" torch.prim.If.yield %int6 : !torch.int\n" +" }\n" +" return %7 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.floor_divide\"(%arg0: !torch.tuple, %arg1: !torch.tuple) -> !torch.int {\n" +" %str = torch.constant.str \"AssertionError: Result dtype for aten.floor_divide bool\"\n" +" %int11 = torch.constant.int 11\n" +" %str_0 = torch.constant.str \"AssertionError: `other` cannot be complex\"\n" +" %none = torch.constant.none\n" +" %str_1 = torch.constant.str \"AssertionError: `self` cannot be complex\"\n" +" %0:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" +" %1:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%1#1) : (!torch.int) -> !torch.bool\n" +" %3 = torch.aten.__not__ %2 : !torch.bool -> !torch.bool\n" +" torch.prim.If %3 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str_1, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%0#1) : (!torch.int) -> !torch.bool\n" +" %5 = torch.aten.__not__ %4 : !torch.bool -> !torch.bool\n" +" torch.prim.If %5 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %6 = torch.prim.ListConstruct %1#0, %0#0 : (!torch.int, !torch.int) -> !torch.list>\n" +" %7 = torch.prim.ListConstruct %1#1, %0#1 : (!torch.int, !torch.int) -> !torch.list\n" +" %8 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%6, %7) : (!torch.list>, !torch.list) -> !torch.int\n" +" %9 = torch.aten.ne.int %8, %int11 : !torch.int, !torch.int -> !torch.bool\n" +" 
torch.prim.If %9 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" return %8 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.matmul\"(%arg0: !torch.tuple, %arg1: !torch.tuple) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" +" %1:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_priority_of_dtype(%0#1) : (!torch.int) -> !torch.int\n" +" %3 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_priority_of_dtype(%1#1) : (!torch.int) -> !torch.int\n" +" %4 = torch.aten.lt.int %2, %3 : !torch.int, !torch.int -> !torch.bool\n" +" %5 = torch.prim.If %4 -> (!torch.int) {\n" +" torch.prim.If.yield %0#1 : !torch.int\n" +" } else {\n" +" torch.prim.If.yield %1#1 : !torch.int\n" +" }\n" +" return %5 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.maximum\"(%arg0: !torch.tuple, %arg1: !torch.tuple) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" +" %1:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %2 = torch.prim.ListConstruct %1#0, %0#0 : (!torch.int, !torch.int) -> !torch.list>\n" +" %3 = torch.prim.ListConstruct %1#1, %0#1 : (!torch.int, !torch.int) -> !torch.list\n" +" %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list>, !torch.list) -> !torch.int\n" +" return %4 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.minimum\"(%arg0: !torch.tuple, %arg1: !torch.tuple) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" +" %1:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %2 = torch.prim.ListConstruct %1#0, %0#0 : (!torch.int, !torch.int) -> !torch.list>\n" +" %3 = torch.prim.ListConstruct %1#1, %0#1 : (!torch.int, !torch.int) -> !torch.list\n" +" %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list>, !torch.list) -> !torch.int\n" +" return %4 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.mm\"(%arg0: !torch.tuple, %arg1: !torch.tuple) -> !torch.int {\n" +" %false = torch.constant.bool false\n" +" %int5 = torch.constant.int 5\n" +" %int15 = torch.constant.int 15\n" +" %0:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" +" %1:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %2 = torch.prim.ListConstruct %int15, %int5 : (!torch.int, !torch.int) -> !torch.list\n" +" %3 = torch.aten.__contains__.int_list %2, %1#1 : !torch.list, !torch.int -> !torch.bool\n" +" %4 = torch.prim.If %3 -> (!torch.bool) {\n" +" %7 = torch.aten.__contains__.int_list %2, %0#1 : !torch.list, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %7 : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %false : !torch.bool\n" +" }\n" +" %5 = torch.prim.If %4 -> (!torch.bool) {\n" +" %7 = torch.aten.ne.int %1#1, %0#1 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %7 : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %false : !torch.bool\n" +" }\n" +" %6 = torch.prim.If %5 -> (!torch.int) {\n" +" torch.prim.If.yield %int5 : !torch.int\n" 
+" } else {\n" +" %7 = torch.prim.ListConstruct %1#0, %0#0 : (!torch.int, !torch.int) -> !torch.list>\n" +" %8 = torch.prim.ListConstruct %1#1, %0#1 : (!torch.int, !torch.int) -> !torch.list\n" +" %9 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%7, %8) : (!torch.list>, !torch.list) -> !torch.int\n" +" torch.prim.If.yield %9 : !torch.int\n" +" }\n" +" return %6 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.mse_loss\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.int) -> !torch.int {\n" +" %none = torch.constant.none\n" +" %str = torch.constant.str \"AssertionError: \"\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" +" %2 = torch.prim.ListConstruct %0#0, %1#0 : (!torch.int, !torch.int) -> !torch.list>\n" +" %3 = torch.prim.ListConstruct %0#1, %1#1 : (!torch.int, !torch.int) -> !torch.list\n" +" %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list>, !torch.list) -> !torch.int\n" +" %5 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_integer_dtype(%4) : (!torch.int) -> !torch.bool\n" +" %6 = torch.aten.__not__ %5 : !torch.bool -> !torch.bool\n" +" torch.prim.If %6 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" return %4 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.mul.Tensor\"(%arg0: !torch.tuple, %arg1: !torch.tuple) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" +" %1:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %2 = torch.prim.ListConstruct %1#0, %0#0 : (!torch.int, !torch.int) -> !torch.list>\n" +" %3 = torch.prim.ListConstruct %1#1, %0#1 : (!torch.int, !torch.int) -> !torch.list\n" +" %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list>, !torch.list) -> !torch.int\n" +" return %4 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.mv\"(%arg0: !torch.tuple, %arg1: !torch.tuple) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" +" %2 = torch.prim.ListConstruct %0#0, %1#0 : (!torch.int, !torch.int) -> !torch.list>\n" +" %3 = torch.prim.ListConstruct %0#1, %1#1 : (!torch.int, !torch.int) -> !torch.list\n" +" %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list>, !torch.list) -> !torch.int\n" +" return %4 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.sub.Tensor\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.union) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" +" %1:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %2 = torch.prim.ListConstruct %1#0, %0#0 : (!torch.int, !torch.int) -> !torch.list>\n" +" %3 = torch.prim.ListConstruct %1#1, %0#1 : (!torch.int, !torch.int) -> !torch.list\n" +" %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list>, 
!torch.list) -> !torch.int\n" +" return %4 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.threshold_backward\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.union) -> !torch.int {\n" +" %str = torch.constant.str \"AssertionError: Result dtype for aten.threshold_backward cannot be bool or float16\"\n" +" %int5 = torch.constant.int 5\n" +" %int11 = torch.constant.int 11\n" +" %str_0 = torch.constant.str \"AssertionError: `self` cannot be complex\"\n" +" %none = torch.constant.none\n" +" %str_1 = torch.constant.str \"AssertionError: `grad_output` cannot be complex\"\n" +" %0:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" +" %1:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%1#1) : (!torch.int) -> !torch.bool\n" +" %3 = torch.aten.__not__ %2 : !torch.bool -> !torch.bool\n" +" torch.prim.If %3 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str_1, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%0#1) : (!torch.int) -> !torch.bool\n" +" %5 = torch.aten.__not__ %4 : !torch.bool -> !torch.bool\n" +" torch.prim.If %5 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %6 = torch.prim.ListConstruct %1#0, %0#0 : (!torch.int, !torch.int) -> !torch.list>\n" +" %7 = torch.prim.ListConstruct %1#1, %0#1 : (!torch.int, !torch.int) -> !torch.list\n" +" %8 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%6, %7) : (!torch.list>, !torch.list) -> !torch.int\n" +" %9 = torch.prim.ListConstruct %int11, %int5 : (!torch.int, !torch.int) -> !torch.list\n" +" %10 = torch.aten.__contains__.int_list %9, %8 : !torch.list, !torch.int -> !torch.bool\n" +" %11 = torch.aten.__not__ %10 : !torch.bool -> !torch.bool\n" +" torch.prim.If %11 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" return %8 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten._convolution\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.optional>, %arg3: !torch.list, %arg4: !torch.list, %arg5: !torch.list, %arg6: !torch.bool, %arg7: !torch.list, %arg8: !torch.int, %arg9: !torch.bool, %arg10: !torch.bool, %arg11: !torch.bool, %arg12: !torch.bool) -> !torch.int {\n" +" %none = torch.constant.none\n" +" %str = torch.constant.str \"AssertionError: \"\n" +" %false = torch.constant.bool false\n" +" %int5 = torch.constant.int 5\n" +" %int11 = torch.constant.int 11\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" +" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%0#1) : (!torch.int) -> !torch.bool\n" +" %3 = torch.aten.__not__ %2 : !torch.bool -> !torch.bool\n" +" %4 = torch.prim.If %3 -> (!torch.bool) {\n" +" %11 = torch.prim.ListConstruct %int11, %int5 : (!torch.int, !torch.int) -> !torch.list\n" +" %12 = torch.aten.__contains__.int_list %11, %0#1 : !torch.list, !torch.int -> !torch.bool\n" +" %13 = torch.aten.__not__ %12 
: !torch.bool -> !torch.bool\n" +" torch.prim.If.yield %13 : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %false : !torch.bool\n" +" }\n" +" torch.prim.If %4 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %5 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%1#1) : (!torch.int) -> !torch.bool\n" +" %6 = torch.aten.__not__ %5 : !torch.bool -> !torch.bool\n" +" %7 = torch.prim.If %6 -> (!torch.bool) {\n" +" %11 = torch.prim.ListConstruct %int11, %int5 : (!torch.int, !torch.int) -> !torch.list\n" +" %12 = torch.aten.__contains__.int_list %11, %1#1 : !torch.list, !torch.int -> !torch.bool\n" +" %13 = torch.aten.__not__ %12 : !torch.bool -> !torch.bool\n" +" torch.prim.If.yield %13 : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %false : !torch.bool\n" +" }\n" +" torch.prim.If %7 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %8 = torch.prim.ListConstruct %0#0, %1#0 : (!torch.int, !torch.int) -> !torch.list>\n" +" %9 = torch.prim.ListConstruct %0#1, %1#1 : (!torch.int, !torch.int) -> !torch.list\n" +" %10 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%8, %9) : (!torch.list>, !torch.list) -> !torch.int\n" +" return %10 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten._convolution.deprecated\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.optional>, %arg3: !torch.list, %arg4: !torch.list, %arg5: !torch.list, %arg6: !torch.bool, %arg7: !torch.list, %arg8: !torch.int, %arg9: !torch.bool, %arg10: !torch.bool, %arg11: !torch.bool) -> !torch.int {\n" +" %none = torch.constant.none\n" +" %str = torch.constant.str \"AssertionError: \"\n" +" %false = torch.constant.bool false\n" +" %int5 = torch.constant.int 5\n" +" %int11 = torch.constant.int 11\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" +" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%0#1) : (!torch.int) -> !torch.bool\n" +" %3 = torch.aten.__not__ %2 : !torch.bool -> !torch.bool\n" +" %4 = torch.prim.If %3 -> (!torch.bool) {\n" +" %11 = torch.prim.ListConstruct %int11, %int5 : (!torch.int, !torch.int) -> !torch.list\n" +" %12 = torch.aten.__contains__.int_list %11, %0#1 : !torch.list, !torch.int -> !torch.bool\n" +" %13 = torch.aten.__not__ %12 : !torch.bool -> !torch.bool\n" +" torch.prim.If.yield %13 : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %false : !torch.bool\n" +" }\n" +" torch.prim.If %4 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %5 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%1#1) : (!torch.int) -> !torch.bool\n" +" %6 = torch.aten.__not__ %5 : !torch.bool -> !torch.bool\n" +" %7 = torch.prim.If %6 -> (!torch.bool) {\n" +" %11 = torch.prim.ListConstruct %int11, %int5 : (!torch.int, !torch.int) -> !torch.list\n" +" %12 = torch.aten.__contains__.int_list %11, %1#1 : !torch.list, !torch.int -> !torch.bool\n" +" %13 = torch.aten.__not__ %12 : !torch.bool -> !torch.bool\n" +" torch.prim.If.yield 
%13 : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %false : !torch.bool\n" +" }\n" +" torch.prim.If %7 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %8 = torch.prim.ListConstruct %0#0, %1#0 : (!torch.int, !torch.int) -> !torch.list>\n" +" %9 = torch.prim.ListConstruct %0#1, %1#1 : (!torch.int, !torch.int) -> !torch.list\n" +" %10 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%8, %9) : (!torch.list>, !torch.list) -> !torch.int\n" +" return %10 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.conv2d\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.optional>, %arg3: !torch.list, %arg4: !torch.list, %arg5: !torch.list, %arg6: !torch.int) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.conv_transpose2d.input\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.optional>, %arg3: !torch.list, %arg4: !torch.list, %arg5: !torch.list, %arg6: !torch.int, %arg7: !torch.list) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.convolution\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.optional>, %arg3: !torch.list, %arg4: !torch.list, %arg5: !torch.list, %arg6: !torch.bool, %arg7: !torch.list, %arg8: !torch.int) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.convolution_backward\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.tuple, %arg3: !torch.optional>, %arg4: !torch.list, %arg5: !torch.list, %arg6: !torch.list, %arg7: !torch.bool, %arg8: !torch.list, %arg9: !torch.int, %arg10: !torch.list) -> !torch.tuple {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = torch.prim.TupleConstruct %0#1, %0#1, %0#1 : !torch.int, !torch.int, !torch.int -> !torch.tuple\n" +" return %1 : !torch.tuple\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.bincount\"(%arg0: !torch.tuple, %arg1: !torch.optional>, %arg2: !torch.int) -> !torch.int {\n" +" %int7 = torch.constant.int 7\n" +" %int4 = torch.constant.int 4\n" +" %none = torch.constant.none\n" +" %str = torch.constant.str \"AssertionError: \"\n" +" %false = torch.constant.bool false\n" +" %int11 = torch.constant.int 11\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_integer_dtype(%0#1) : (!torch.int) -> !torch.bool\n" +" %2 = torch.prim.If %1 -> (!torch.bool) {\n" +" %5 = torch.aten.ne.int %0#1, %int11 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %5 : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %false : !torch.bool\n" +" }\n" +" torch.prim.If %2 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %3 = torch.aten.__is__ %arg1, %none : !torch.optional>, !torch.none -> !torch.bool\n" +" %4 = torch.prim.If %3 -> (!torch.int) {\n" +" torch.prim.If.yield %int4 : !torch.int\n" +" } else {\n" +" torch.prim.If.yield %int7 : 
!torch.int\n" +" }\n" +" return %4 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.addmm\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.tuple, %arg3: !torch.union, %arg4: !torch.union) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" +" %2:2 = torch.prim.TupleUnpack %arg2 : !torch.tuple -> !torch.int, !torch.int\n" +" %3 = torch.prim.ListConstruct %0#0, %1#0, %2#0 : (!torch.int, !torch.int, !torch.int) -> !torch.list>\n" +" %4 = torch.prim.ListConstruct %0#1, %1#1, %2#1 : (!torch.int, !torch.int, !torch.int) -> !torch.list\n" +" %5 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%3, %4) : (!torch.list>, !torch.list) -> !torch.int\n" +" return %5 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.lerp.Tensor\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.tuple) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" +" %2:2 = torch.prim.TupleUnpack %arg2 : !torch.tuple -> !torch.int, !torch.int\n" +" %3 = torch.prim.ListConstruct %0#0, %1#0, %2#0 : (!torch.int, !torch.int, !torch.int) -> !torch.list>\n" +" %4 = torch.prim.ListConstruct %0#1, %1#1, %2#1 : (!torch.int, !torch.int, !torch.int) -> !torch.list\n" +" %5 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%3, %4) : (!torch.list>, !torch.list) -> !torch.int\n" +" return %5 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.addcmul\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.tuple, %arg3: !torch.union) -> !torch.int {\n" +" %none = torch.constant.none\n" +" %str = torch.constant.str \"AssertionError: \"\n" +" %int11 = torch.constant.int 11\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" +" %2:2 = torch.prim.TupleUnpack %arg2 : !torch.tuple -> !torch.int, !torch.int\n" +" %3 = torch.aten.ne.int %0#1, %int11 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %3 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %4 = torch.aten.ne.int %1#1, %int11 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %4 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %5 = torch.aten.ne.int %2#1, %int11 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %5 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %6 = torch.prim.ListConstruct %0#0, %1#0, %2#0 : (!torch.int, !torch.int, !torch.int) -> !torch.list>\n" +" %7 = torch.prim.ListConstruct %0#1, %1#1, %2#1 : (!torch.int, !torch.int, !torch.int) -> !torch.list\n" +" %8 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%6, %7) : (!torch.list>, !torch.list) -> !torch.int\n" +" return %8 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.addcdiv\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.tuple, %arg3: 
!torch.union) -> !torch.int {\n" +" %int6 = torch.constant.int 6\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" +" %2:2 = torch.prim.TupleUnpack %arg2 : !torch.tuple -> !torch.int, !torch.int\n" +" %3 = torch.prim.ListConstruct %0#0, %1#0, %2#0 : (!torch.int, !torch.int, !torch.int) -> !torch.list>\n" +" %4 = torch.prim.ListConstruct %0#1, %1#1, %2#1 : (!torch.int, !torch.int, !torch.int) -> !torch.list\n" +" %5 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%3, %4) : (!torch.list>, !torch.list) -> !torch.int\n" +" %6 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_integer_dtype(%5) : (!torch.int) -> !torch.bool\n" +" %7 = torch.prim.If %6 -> (!torch.int) {\n" +" torch.prim.If.yield %int6 : !torch.int\n" +" } else {\n" +" torch.prim.If.yield %5 : !torch.int\n" +" }\n" +" return %7 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.add.Scalar\"(%arg0: !torch.tuple, %arg1: !torch.union, %arg2: !torch.union) -> !torch.int {\n" +" %none = torch.constant.none\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = torch.prim.ListConstruct %0#0, %none : (!torch.int, !torch.none) -> !torch.list>\n" +" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg1) : (!torch.union) -> !torch.int\n" +" %3 = torch.prim.ListConstruct %0#1, %2 : (!torch.int, !torch.int) -> !torch.list\n" +" %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%1, %3) : (!torch.list>, !torch.list) -> !torch.int\n" +" return %4 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.sub.Scalar\"(%arg0: !torch.tuple, %arg1: !torch.union, %arg2: !torch.union) -> !torch.int {\n" +" %none = torch.constant.none\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = torch.prim.ListConstruct %0#0, %none : (!torch.int, !torch.none) -> !torch.list>\n" +" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg1) : (!torch.union) -> !torch.int\n" +" %3 = torch.prim.ListConstruct %0#1, %2 : (!torch.int, !torch.int) -> !torch.list\n" +" %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%1, %3) : (!torch.list>, !torch.list) -> !torch.int\n" +" return %4 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.mul.Scalar\"(%arg0: !torch.tuple, %arg1: !torch.union) -> !torch.int {\n" +" %none = torch.constant.none\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = torch.prim.ListConstruct %0#0, %none : (!torch.int, !torch.none) -> !torch.list>\n" +" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg1) : (!torch.union) -> !torch.int\n" +" %3 = torch.prim.ListConstruct %0#1, %2 : (!torch.int, !torch.int) -> !torch.list\n" +" %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%1, %3) : (!torch.list>, !torch.list) -> !torch.int\n" +" return %4 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.div.Scalar\"(%arg0: !torch.tuple, %arg1: !torch.union) -> !torch.int {\n" +" %int6 = torch.constant.int 6\n" 
+" %none = torch.constant.none\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = torch.prim.ListConstruct %0#0, %none : (!torch.int, !torch.none) -> !torch.list>\n" +" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg1) : (!torch.union) -> !torch.int\n" +" %3 = torch.prim.ListConstruct %0#1, %2 : (!torch.int, !torch.int) -> !torch.list\n" +" %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%1, %3) : (!torch.list>, !torch.list) -> !torch.int\n" +" %5 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_integer_dtype(%4) : (!torch.int) -> !torch.bool\n" +" %6 = torch.prim.If %5 -> (!torch.int) {\n" +" torch.prim.If.yield %int6 : !torch.int\n" +" } else {\n" +" torch.prim.If.yield %4 : !torch.int\n" +" }\n" +" return %6 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.fmod.Scalar\"(%arg0: !torch.tuple, %arg1: !torch.union) -> !torch.int {\n" +" %none = torch.constant.none\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = torch.prim.ListConstruct %0#0, %none : (!torch.int, !torch.none) -> !torch.list>\n" +" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg1) : (!torch.union) -> !torch.int\n" +" %3 = torch.prim.ListConstruct %0#1, %2 : (!torch.int, !torch.int) -> !torch.list\n" +" %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%1, %3) : (!torch.list>, !torch.list) -> !torch.int\n" +" return %4 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.floor_divide.Scalar\"(%arg0: !torch.tuple, %arg1: !torch.union) -> !torch.int {\n" +" %none = torch.constant.none\n" +" %str = torch.constant.str \"AssertionError: \"\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%0#1) : (!torch.int) -> !torch.bool\n" +" %2 = torch.aten.__not__ %1 : !torch.bool -> !torch.bool\n" +" torch.prim.If %2 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %3 = torch.prim.ListConstruct %0#0, %none : (!torch.int, !torch.none) -> !torch.list>\n" +" %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg1) : (!torch.union) -> !torch.int\n" +" %5 = torch.prim.ListConstruct %0#1, %4 : (!torch.int, !torch.int) -> !torch.list\n" +" %6 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%3, %5) : (!torch.list>, !torch.list) -> !torch.int\n" +" return %6 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.pow.Tensor_Scalar\"(%arg0: !torch.tuple, %arg1: !torch.union) -> !torch.int {\n" +" %none = torch.constant.none\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = torch.prim.ListConstruct %0#0, %none : (!torch.int, !torch.none) -> !torch.list>\n" +" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg1) : (!torch.union) -> !torch.int\n" +" %3 = torch.prim.ListConstruct %0#1, %2 : (!torch.int, !torch.int) -> !torch.list\n" +" %4 = call 
@__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%1, %3) : (!torch.list>, !torch.list) -> !torch.int\n" +" return %4 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.leaky_relu\"(%arg0: !torch.tuple, %arg1: !torch.union) -> !torch.int {\n" +" %none = torch.constant.none\n" +" %str = torch.constant.str \"AssertionError: \"\n" +" %int11 = torch.constant.int 11\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = torch.aten.ne.int %0#1, %int11 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %1 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %2 = torch.prim.ListConstruct %0#0, %none : (!torch.int, !torch.none) -> !torch.list>\n" +" %3 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg1) : (!torch.union) -> !torch.int\n" +" %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_float_dtype(%3) : (!torch.int) -> !torch.bool\n" +" torch.prim.If %4 -> () {\n" +" %7 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_integer_dtype(%0#1) : (!torch.int) -> !torch.bool\n" +" %8 = torch.aten.__not__ %7 : !torch.bool -> !torch.bool\n" +" torch.prim.If %8 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.If.yield\n" +" }\n" +" %5 = torch.prim.ListConstruct %0#1, %3 : (!torch.int, !torch.int) -> !torch.list\n" +" %6 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%2, %5) : (!torch.list>, !torch.list) -> !torch.int\n" +" return %6 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.remainder.Scalar\"(%arg0: !torch.tuple, %arg1: !torch.union) -> !torch.int {\n" +" %none = torch.constant.none\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = torch.prim.ListConstruct %0#0, %none : (!torch.int, !torch.none) -> !torch.list>\n" +" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg1) : (!torch.union) -> !torch.int\n" +" %3 = torch.prim.ListConstruct %0#1, %2 : (!torch.int, !torch.int) -> !torch.list\n" +" %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%1, %3) : (!torch.list>, !torch.list) -> !torch.int\n" +" return %4 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.baddbmm\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.tuple, %arg3: !torch.union, %arg4: !torch.union) -> !torch.int {\n" +" %none = torch.constant.none\n" +" %str = torch.constant.str \"AssertionError: \"\n" +" %int5 = torch.constant.int 5\n" +" %int11 = torch.constant.int 11\n" +" %0:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" +" %1:2 = torch.prim.TupleUnpack %arg2 : !torch.tuple -> !torch.int, !torch.int\n" +" %2 = torch.prim.ListConstruct %int11, %int5 : (!torch.int, !torch.int) -> !torch.list\n" +" %3 = torch.aten.__contains__.int_list %2, %0#1 : !torch.list, !torch.int -> !torch.bool\n" +" %4 = torch.aten.__not__ %3 : !torch.bool -> !torch.bool\n" +" torch.prim.If %4 -> () {\n" +" torch.prim.If.yield\n" +" } 
else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %5 = torch.prim.ListConstruct %int11, %int5 : (!torch.int, !torch.int) -> !torch.list\n" +" %6 = torch.aten.__contains__.int_list %5, %1#1 : !torch.list, !torch.int -> !torch.bool\n" +" %7 = torch.aten.__not__ %6 : !torch.bool -> !torch.bool\n" +" torch.prim.If %7 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %8 = torch.aten.eq.int %0#1, %1#1 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %8 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %9 = torch.prim.ListConstruct %0#0, %1#0 : (!torch.int, !torch.int) -> !torch.list>\n" +" %10 = torch.prim.ListConstruct %0#1, %1#1 : (!torch.int, !torch.int) -> !torch.list\n" +" %11 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%9, %10) : (!torch.list>, !torch.list) -> !torch.int\n" +" return %11 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.where.self\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.tuple) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" +" %1:2 = torch.prim.TupleUnpack %arg2 : !torch.tuple -> !torch.int, !torch.int\n" +" %2 = torch.prim.ListConstruct %0#0, %1#0 : (!torch.int, !torch.int) -> !torch.list>\n" +" %3 = torch.prim.ListConstruct %0#1, %1#1 : (!torch.int, !torch.int) -> !torch.list\n" +" %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list>, !torch.list) -> !torch.int\n" +" return %4 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.where.Scalar\"(%arg0: !torch.tuple, %arg1: !torch.union, %arg2: !torch.union) -> !torch.int {\n" +" %int6 = torch.constant.int 6\n" +" %int4 = torch.constant.int 4\n" +" %false = torch.constant.bool false\n" +" %0 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg1) : (!torch.union) -> !torch.int\n" +" %1 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_integer_dtype(%0) : (!torch.int) -> !torch.bool\n" +" %2 = torch.prim.If %1 -> (!torch.bool) {\n" +" %4 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg2) : (!torch.union) -> !torch.int\n" +" %5 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_integer_dtype(%4) : (!torch.int) -> !torch.bool\n" +" torch.prim.If.yield %5 : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %false : !torch.bool\n" +" }\n" +" %3 = torch.prim.If %2 -> (!torch.int) {\n" +" torch.prim.If.yield %int4 : !torch.int\n" +" } else {\n" +" torch.prim.If.yield %int6 : !torch.int\n" +" }\n" +" return %3 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.where.ScalarOther\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.union) -> !torch.int {\n" +" %none = torch.constant.none\n" +" %0:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = torch.prim.ListConstruct %0#0, %none : (!torch.int, !torch.none) -> !torch.list>\n" +" %2 = call 
@__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg2) : (!torch.union) -> !torch.int\n" +" %3 = torch.prim.ListConstruct %0#1, %2 : (!torch.int, !torch.int) -> !torch.list\n" +" %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%1, %3) : (!torch.list>, !torch.list) -> !torch.int\n" +" return %4 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.where.ScalarSelf\"(%arg0: !torch.tuple, %arg1: !torch.union, %arg2: !torch.tuple) -> !torch.int {\n" +" %none = torch.constant.none\n" +" %0:2 = torch.prim.TupleUnpack %arg2 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = torch.prim.ListConstruct %none, %0#0 : (!torch.none, !torch.int) -> !torch.list>\n" +" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg1) : (!torch.union) -> !torch.int\n" +" %3 = torch.prim.ListConstruct %2, %0#1 : (!torch.int, !torch.int) -> !torch.list\n" +" %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%1, %3) : (!torch.list>, !torch.list) -> !torch.int\n" +" return %4 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.nll_loss_forward\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.optional>, %arg3: !torch.int, %arg4: !torch.int) -> !torch.tuple {\n" +" %none = torch.constant.none\n" +" %str = torch.constant.str \"AssertionError: \"\n" +" %int4 = torch.constant.int 4\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" +" %2 = torch.aten.eq.int %1#1, %int4 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %2 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %3 = torch.prim.TupleConstruct %0#1, %0#1 : !torch.int, !torch.int -> !torch.tuple\n" +" return %3 : !torch.tuple\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.native_layer_norm\"(%arg0: !torch.tuple, %arg1: !torch.list, %arg2: !torch.optional>, %arg3: !torch.optional>, %arg4: !torch.float) -> !torch.tuple {\n" +" %int7 = torch.constant.int 7\n" +" %int10 = torch.constant.int 10\n" +" %int6 = torch.constant.int 6\n" +" %int9 = torch.constant.int 9\n" +" %none = torch.constant.none\n" +" %str = torch.constant.str \"AssertionError: \"\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_integer_dtype(%0#1) : (!torch.int) -> !torch.bool\n" +" %2 = torch.aten.__not__ %1 : !torch.bool -> !torch.bool\n" +" torch.prim.If %2 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %3 = torch.aten.eq.int %0#1, %int9 : !torch.int, !torch.int -> !torch.bool\n" +" %4 = torch.prim.If %3 -> (!torch.int) {\n" +" torch.prim.If.yield %int6 : !torch.int\n" +" } else {\n" +" torch.prim.If.yield %0#1 : !torch.int\n" +" }\n" +" %5 = torch.aten.eq.int %0#1, %int10 : !torch.int, !torch.int -> !torch.bool\n" +" %6 = torch.prim.If %5 -> (!torch.int) {\n" +" torch.prim.If.yield %int7 : !torch.int\n" +" } else {\n" +" torch.prim.If.yield %4 : !torch.int\n" +" }\n" +" %7 = torch.prim.TupleConstruct %0#1, %0#1, %6 : !torch.int, !torch.int, 
!torch.int -> !torch.tuple\n" +" return %7 : !torch.tuple\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.native_batch_norm\"(%arg0: !torch.tuple, %arg1: !torch.optional>, %arg2: !torch.optional>, %arg3: !torch.optional>, %arg4: !torch.optional>, %arg5: !torch.bool, %arg6: !torch.float, %arg7: !torch.float) -> !torch.tuple {\n" +" %int6 = torch.constant.int 6\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_integer_dtype(%0#1) : (!torch.int) -> !torch.bool\n" +" %2 = torch.prim.If %1 -> (!torch.int) {\n" +" torch.prim.If.yield %int6 : !torch.int\n" +" } else {\n" +" torch.prim.If.yield %0#1 : !torch.int\n" +" }\n" +" %3 = torch.prim.TupleConstruct %0#1, %0#1, %2 : !torch.int, !torch.int, !torch.int -> !torch.tuple\n" +" return %3 : !torch.tuple\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.arange\"(%arg0: !torch.union, %arg1: !torch.optional, %arg2: !torch.optional, %arg3: !torch.optional, %arg4: !torch.optional) -> !torch.int {\n" +" %int4 = torch.constant.int 4\n" +" %int6 = torch.constant.int 6\n" +" %str = torch.constant.str \"AssertionError: \"\n" +" %none = torch.constant.none\n" +" %0 = torch.aten.__isnot__ %arg1, %none : !torch.optional, !torch.none -> !torch.bool\n" +" %1 = torch.prim.If %0 -> (!torch.int) {\n" +" %2 = torch.prim.unchecked_cast %arg1 : !torch.optional -> !torch.int\n" +" %3 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%2) : (!torch.int) -> !torch.bool\n" +" %4 = torch.aten.__not__ %3 : !torch.bool -> !torch.bool\n" +" torch.prim.If %4 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" torch.prim.If.yield %2 : !torch.int\n" +" } else {\n" +" %2 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg0) : (!torch.union) -> !torch.int\n" +" %3 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_float_dtype(%2) : (!torch.int) -> !torch.bool\n" +" %4 = torch.prim.If %3 -> (!torch.int) {\n" +" torch.prim.If.yield %int6 : !torch.int\n" +" } else {\n" +" torch.prim.If.yield %int4 : !torch.int\n" +" }\n" +" torch.prim.If.yield %4 : !torch.int\n" +" }\n" +" return %1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.arange.start\"(%arg0: !torch.union, %arg1: !torch.union, %arg2: !torch.optional, %arg3: !torch.optional, %arg4: !torch.optional, %arg5: !torch.optional) -> !torch.int {\n" +" %int4 = torch.constant.int 4\n" +" %int6 = torch.constant.int 6\n" +" %true = torch.constant.bool true\n" +" %str = torch.constant.str \"AssertionError: \"\n" +" %none = torch.constant.none\n" +" %0 = torch.aten.__isnot__ %arg2, %none : !torch.optional, !torch.none -> !torch.bool\n" +" %1 = torch.prim.If %0 -> (!torch.int) {\n" +" %2 = torch.prim.unchecked_cast %arg2 : !torch.optional -> !torch.int\n" +" %3 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%2) : (!torch.int) -> !torch.bool\n" +" %4 = torch.aten.__not__ %3 : !torch.bool -> !torch.bool\n" +" torch.prim.If %4 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" torch.prim.If.yield %2 : !torch.int\n" +" } else {\n" +" %2 
= func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg0) : (!torch.union) -> !torch.int\n" +" %3 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_float_dtype(%2) : (!torch.int) -> !torch.bool\n" +" %4 = torch.prim.If %3 -> (!torch.bool) {\n" +" torch.prim.If.yield %true : !torch.bool\n" +" } else {\n" +" %6 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg1) : (!torch.union) -> !torch.int\n" +" %7 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_float_dtype(%6) : (!torch.int) -> !torch.bool\n" +" torch.prim.If.yield %7 : !torch.bool\n" +" }\n" +" %5 = torch.prim.If %4 -> (!torch.int) {\n" +" torch.prim.If.yield %int6 : !torch.int\n" +" } else {\n" +" torch.prim.If.yield %int4 : !torch.int\n" +" }\n" +" torch.prim.If.yield %5 : !torch.int\n" +" }\n" +" return %1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.arange.start_step\"(%arg0: !torch.union, %arg1: !torch.union, %arg2: !torch.union, %arg3: !torch.optional, %arg4: !torch.optional, %arg5: !torch.optional, %arg6: !torch.optional) -> !torch.int {\n" +" %int4 = torch.constant.int 4\n" +" %int6 = torch.constant.int 6\n" +" %true = torch.constant.bool true\n" +" %str = torch.constant.str \"AssertionError: \"\n" +" %none = torch.constant.none\n" +" %0 = torch.aten.__isnot__ %arg3, %none : !torch.optional, !torch.none -> !torch.bool\n" +" %1 = torch.prim.If %0 -> (!torch.int) {\n" +" %2 = torch.prim.unchecked_cast %arg3 : !torch.optional -> !torch.int\n" +" %3 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%2) : (!torch.int) -> !torch.bool\n" +" %4 = torch.aten.__not__ %3 : !torch.bool -> !torch.bool\n" +" torch.prim.If %4 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" torch.prim.If.yield %2 : !torch.int\n" +" } else {\n" +" %2 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg0) : (!torch.union) -> !torch.int\n" +" %3 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_float_dtype(%2) : (!torch.int) -> !torch.bool\n" +" %4 = torch.prim.If %3 -> (!torch.bool) {\n" +" torch.prim.If.yield %true : !torch.bool\n" +" } else {\n" +" %7 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg1) : (!torch.union) -> !torch.int\n" +" %8 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_float_dtype(%7) : (!torch.int) -> !torch.bool\n" +" torch.prim.If.yield %8 : !torch.bool\n" +" }\n" +" %5 = torch.prim.If %4 -> (!torch.bool) {\n" +" torch.prim.If.yield %true : !torch.bool\n" +" } else {\n" +" %7 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg2) : (!torch.union) -> !torch.int\n" +" %8 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_float_dtype(%7) : (!torch.int) -> !torch.bool\n" +" torch.prim.If.yield %8 : !torch.bool\n" +" }\n" +" %6 = torch.prim.If %5 -> (!torch.int) {\n" +" torch.prim.If.yield %int6 : !torch.int\n" +" } else {\n" +" torch.prim.If.yield %int4 : !torch.int\n" +" }\n" +" 
torch.prim.If.yield %6 : !torch.int\n" +" }\n" +" return %1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.sum\"(%arg0: !torch.tuple, %arg1: !torch.optional) -> !torch.int {\n" +" %int4 = torch.constant.int 4\n" +" %none = torch.constant.none\n" +" %0 = torch.aten.__isnot__ %arg1, %none : !torch.optional, !torch.none -> !torch.bool\n" +" %1 = torch.prim.If %0 -> (!torch.int) {\n" +" %2 = torch.prim.unchecked_cast %arg1 : !torch.optional -> !torch.int\n" +" torch.prim.If.yield %2 : !torch.int\n" +" } else {\n" +" %2:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %3 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_integer_dtype(%2#1) : (!torch.int) -> !torch.bool\n" +" %4 = torch.prim.If %3 -> (!torch.int) {\n" +" torch.prim.If.yield %int4 : !torch.int\n" +" } else {\n" +" torch.prim.If.yield %2#1 : !torch.int\n" +" }\n" +" torch.prim.If.yield %4 : !torch.int\n" +" }\n" +" return %1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.sum.dim_IntList\"(%arg0: !torch.tuple, %arg1: !torch.optional>, %arg2: !torch.bool, %arg3: !torch.optional) -> !torch.int {\n" +" %0 = call @\"__torch_mlir_dtype_fn.aten.sum\"(%arg0, %arg3) : (!torch.tuple, !torch.optional) -> !torch.int\n" +" return %0 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.mean.dim\"(%arg0: !torch.tuple, %arg1: !torch.optional>, %arg2: !torch.bool, %arg3: !torch.optional) -> !torch.int {\n" +" %none = torch.constant.none\n" +" %str = torch.constant.str \"AssertionError: \"\n" +" %0 = call @\"__torch_mlir_dtype_fn.aten.sum\"(%arg0, %arg3) : (!torch.tuple, !torch.optional) -> !torch.int\n" +" %1 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_integer_dtype(%0) : (!torch.int) -> !torch.bool\n" +" %2 = torch.aten.__not__ %1 : !torch.bool -> !torch.bool\n" +" torch.prim.If %2 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" return %0 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.argmax\"(%arg0: !torch.tuple, %arg1: !torch.optional, %arg2: !torch.bool) -> !torch.int {\n" +" %int4 = torch.constant.int 4\n" +" return %int4 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.any.dim\"(%arg0: !torch.tuple, %arg1: !torch.int, %arg2: !torch.bool) -> !torch.int {\n" +" %int11 = torch.constant.int 11\n" +" %int0 = torch.constant.int 0\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = torch.aten.eq.int %0#1, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" %2 = torch.prim.If %1 -> (!torch.int) {\n" +" torch.prim.If.yield %0#1 : !torch.int\n" +" } else {\n" +" torch.prim.If.yield %int11 : !torch.int\n" +" }\n" +" return %2 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.max\"(%arg0: !torch.tuple) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.amax\"(%arg0: !torch.tuple, %arg1: !torch.list, %arg2: !torch.bool) -> !torch.int {\n" +" %0 = call @\"__torch_mlir_dtype_fn.aten.max\"(%arg0) : (!torch.tuple) -> !torch.int\n" +" return %0 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.max.dim\"(%arg0: !torch.tuple, %arg1: !torch.int, %arg2: !torch.bool) -> !torch.tuple {\n" +" %int4 = torch.constant.int 4\n" +" %0 = call 
@\"__torch_mlir_dtype_fn.aten.max\"(%arg0) : (!torch.tuple) -> !torch.int\n" +" %1 = torch.prim.TupleConstruct %0, %int4 : !torch.int, !torch.int -> !torch.tuple\n" +" return %1 : !torch.tuple\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.mean\"(%arg0: !torch.tuple, %arg1: !torch.optional) -> !torch.int {\n" +" %false = torch.constant.bool false\n" +" %none = torch.constant.none\n" +" %0 = torch.derefine %none : !torch.none to !torch.optional>\n" +" %1 = call @\"__torch_mlir_dtype_fn.aten.mean.dim\"(%arg0, %0, %false, %arg1) : (!torch.tuple, !torch.optional>, !torch.bool, !torch.optional) -> !torch.int\n" +" return %1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.std\"(%arg0: !torch.tuple, %arg1: !torch.bool) -> !torch.int {\n" +" %int7 = torch.constant.int 7\n" +" %int10 = torch.constant.int 10\n" +" %int6 = torch.constant.int 6\n" +" %int9 = torch.constant.int 9\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = torch.aten.eq.int %0#1, %int9 : !torch.int, !torch.int -> !torch.bool\n" +" %2 = torch.prim.If %1 -> (!torch.int) {\n" +" torch.prim.If.yield %int6 : !torch.int\n" +" } else {\n" +" %3 = torch.aten.eq.int %0#1, %int10 : !torch.int, !torch.int -> !torch.bool\n" +" %4 = torch.prim.If %3 -> (!torch.int) {\n" +" torch.prim.If.yield %int7 : !torch.int\n" +" } else {\n" +" torch.prim.If.yield %0#1 : !torch.int\n" +" }\n" +" torch.prim.If.yield %4 : !torch.int\n" +" }\n" +" return %2 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.std.dim\"(%arg0: !torch.tuple, %arg1: !torch.optional>, %arg2: !torch.bool, %arg3: !torch.bool) -> !torch.int {\n" +" %true = torch.constant.bool true\n" +" %0 = call @\"__torch_mlir_dtype_fn.aten.std\"(%arg0, %true) : (!torch.tuple, !torch.bool) -> !torch.int\n" +" return %0 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.std.correction\"(%arg0: !torch.tuple, %arg1: !torch.optional>, %arg2: !torch.union, %arg3: !torch.bool) -> !torch.int {\n" +" %true = torch.constant.bool true\n" +" %0 = call @\"__torch_mlir_dtype_fn.aten.std\"(%arg0, %true) : (!torch.tuple, !torch.bool) -> !torch.int\n" +" return %0 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.var\"(%arg0: !torch.tuple, %arg1: !torch.bool) -> !torch.int {\n" +" %true = torch.constant.bool true\n" +" %0 = call @\"__torch_mlir_dtype_fn.aten.std\"(%arg0, %true) : (!torch.tuple, !torch.bool) -> !torch.int\n" +" return %0 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.var.dim\"(%arg0: !torch.tuple, %arg1: !torch.optional>, %arg2: !torch.bool, %arg3: !torch.bool) -> !torch.int {\n" +" %true = torch.constant.bool true\n" +" %0 = call @\"__torch_mlir_dtype_fn.aten.std\"(%arg0, %true) : (!torch.tuple, !torch.bool) -> !torch.int\n" +" return %0 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.var.correction\"(%arg0: !torch.tuple, %arg1: !torch.optional>, %arg2: !torch.union, %arg3: !torch.bool) -> !torch.int {\n" +" %true = torch.constant.bool true\n" +" %0 = call @\"__torch_mlir_dtype_fn.aten.std\"(%arg0, %true) : (!torch.tuple, !torch.bool) -> !torch.int\n" +" return %0 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.prims.var\"(%arg0: !torch.tuple, %arg1: !torch.optional>, %arg2: !torch.float, %arg3: !torch.optional) -> !torch.int {\n" +" %true = torch.constant.bool true\n" +" %0 = call @\"__torch_mlir_dtype_fn.aten.std\"(%arg0, %true) : (!torch.tuple, !torch.bool) -> !torch.int\n" +" return %0 : !torch.int\n" +" }\n" +" func.func 
@\"__torch_mlir_dtype_fn.aten.linalg_vector_norm\"(%arg0: !torch.tuple, %arg1: !torch.union, %arg2: !torch.optional>, %arg3: !torch.bool, %arg4: !torch.optional) -> !torch.int {\n" +" %true = torch.constant.bool true\n" +" %none = torch.constant.none\n" +" %str = torch.constant.str \"AssertionError: \"\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_integer_dtype(%0#1) : (!torch.int) -> !torch.bool\n" +" %2 = torch.aten.__not__ %1 : !torch.bool -> !torch.bool\n" +" torch.prim.If %2 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %3 = torch.aten.__isnot__ %arg4, %none : !torch.optional, !torch.none -> !torch.bool\n" +" %4 = torch.prim.If %3 -> (!torch.int) {\n" +" %5 = torch.prim.unchecked_cast %arg4 : !torch.optional -> !torch.int\n" +" %6 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_integer_dtype(%5) : (!torch.int) -> !torch.bool\n" +" %7 = torch.aten.__not__ %6 : !torch.bool -> !torch.bool\n" +" torch.prim.If %7 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %8 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%0#1) : (!torch.int) -> !torch.bool\n" +" %9 = torch.prim.If %8 -> (!torch.int) {\n" +" %10 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%5) : (!torch.int) -> !torch.bool\n" +" torch.prim.If %10 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %11 = torch.prim.TupleConstruct %0#0, %5 : !torch.int, !torch.int -> !torch.tuple\n" +" %12 = func.call @\"__torch_mlir_dtype_fn.aten.std\"(%11, %true) : (!torch.tuple, !torch.bool) -> !torch.int\n" +" torch.prim.If.yield %12 : !torch.int\n" +" } else {\n" +" %10 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%5) : (!torch.int) -> !torch.bool\n" +" %11 = torch.aten.__not__ %10 : !torch.bool -> !torch.bool\n" +" torch.prim.If %11 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" torch.prim.If.yield %5 : !torch.int\n" +" }\n" +" torch.prim.If.yield %9 : !torch.int\n" +" } else {\n" +" %5 = func.call @\"__torch_mlir_dtype_fn.aten.std\"(%arg0, %true) : (!torch.tuple, !torch.bool) -> !torch.int\n" +" torch.prim.If.yield %5 : !torch.int\n" +" }\n" +" return %4 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.tensor.float\"(%arg0: !torch.float, %arg1: !torch.optional, %arg2: !torch.optional, %arg3: !torch.bool) -> !torch.int {\n" +" %int6 = torch.constant.int 6\n" +" %none = torch.constant.none\n" +" %0 = torch.aten.__is__ %arg1, %none : !torch.optional, !torch.none -> !torch.bool\n" +" %1 = torch.prim.If %0 -> (!torch.int) {\n" +" torch.prim.If.yield %int6 : !torch.int\n" +" } else {\n" +" %2 = torch.prim.unchecked_cast %arg1 : !torch.optional -> !torch.int\n" +" torch.prim.If.yield %2 : !torch.int\n" +" }\n" +" return %1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.tensor.int\"(%arg0: !torch.int, 
%arg1: !torch.optional, %arg2: !torch.optional, %arg3: !torch.bool) -> !torch.int {\n" +" %int4 = torch.constant.int 4\n" +" %none = torch.constant.none\n" +" %0 = torch.aten.__is__ %arg1, %none : !torch.optional, !torch.none -> !torch.bool\n" +" %1 = torch.prim.If %0 -> (!torch.int) {\n" +" torch.prim.If.yield %int4 : !torch.int\n" +" } else {\n" +" %2 = torch.prim.unchecked_cast %arg1 : !torch.optional -> !torch.int\n" +" torch.prim.If.yield %2 : !torch.int\n" +" }\n" +" return %1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.tensor.bool\"(%arg0: !torch.bool, %arg1: !torch.optional, %arg2: !torch.optional, %arg3: !torch.bool) -> !torch.int {\n" +" %int11 = torch.constant.int 11\n" +" %none = torch.constant.none\n" +" %0 = torch.aten.__is__ %arg1, %none : !torch.optional, !torch.none -> !torch.bool\n" +" %1 = torch.prim.If %0 -> (!torch.int) {\n" +" torch.prim.If.yield %int11 : !torch.int\n" +" } else {\n" +" %2 = torch.prim.unchecked_cast %arg1 : !torch.optional -> !torch.int\n" +" torch.prim.If.yield %2 : !torch.int\n" +" }\n" +" return %1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.zeros\"(%arg0: !torch.list, %arg1: !torch.optional, %arg2: !torch.optional, %arg3: !torch.optional, %arg4: !torch.optional) -> !torch.int {\n" +" %int6 = torch.constant.int 6\n" +" %none = torch.constant.none\n" +" %0 = torch.aten.__is__ %arg1, %none : !torch.optional, !torch.none -> !torch.bool\n" +" %1 = torch.prim.If %0 -> (!torch.int) {\n" +" torch.prim.If.yield %int6 : !torch.int\n" +" } else {\n" +" %2 = torch.prim.unchecked_cast %arg1 : !torch.optional -> !torch.int\n" +" torch.prim.If.yield %2 : !torch.int\n" +" }\n" +" return %1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.ones\"(%arg0: !torch.list, %arg1: !torch.optional, %arg2: !torch.optional, %arg3: !torch.optional, %arg4: !torch.optional) -> !torch.int {\n" +" %int6 = torch.constant.int 6\n" +" %none = torch.constant.none\n" +" %0 = torch.aten.__is__ %arg1, %none : !torch.optional, !torch.none -> !torch.bool\n" +" %1 = torch.prim.If %0 -> (!torch.int) {\n" +" torch.prim.If.yield %int6 : !torch.int\n" +" } else {\n" +" %2 = torch.prim.unchecked_cast %arg1 : !torch.optional -> !torch.int\n" +" torch.prim.If.yield %2 : !torch.int\n" +" }\n" +" return %1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.empty.memory_format\"(%arg0: !torch.list, %arg1: !torch.optional, %arg2: !torch.optional, %arg3: !torch.optional, %arg4: !torch.optional, %arg5: !torch.optional) -> !torch.int {\n" +" %int6 = torch.constant.int 6\n" +" %none = torch.constant.none\n" +" %0 = torch.aten.__is__ %arg1, %none : !torch.optional, !torch.none -> !torch.bool\n" +" %1 = torch.prim.If %0 -> (!torch.int) {\n" +" torch.prim.If.yield %int6 : !torch.int\n" +" } else {\n" +" %2 = torch.prim.unchecked_cast %arg1 : !torch.optional -> !torch.int\n" +" torch.prim.If.yield %2 : !torch.int\n" +" }\n" +" return %1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.full\"(%arg0: !torch.list, %arg1: !torch.union, %arg2: !torch.optional, %arg3: !torch.optional, %arg4: !torch.optional, %arg5: !torch.optional) -> !torch.int {\n" +" %int6 = torch.constant.int 6\n" +" %none = torch.constant.none\n" +" %0 = torch.aten.__isnot__ %arg2, %none : !torch.optional, !torch.none -> !torch.bool\n" +" %1 = torch.prim.If %0 -> (!torch.int) {\n" +" %2 = torch.prim.unchecked_cast %arg2 : !torch.optional -> !torch.int\n" +" torch.prim.If.yield %2 : !torch.int\n" +" } else {\n" +" %2 = func.call 
@__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg1) : (!torch.union) -> !torch.int\n" +" %3 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_float_dtype(%2) : (!torch.int) -> !torch.bool\n" +" %4 = torch.prim.If %3 -> (!torch.int) {\n" +" torch.prim.If.yield %int6 : !torch.int\n" +" } else {\n" +" torch.prim.If.yield %2 : !torch.int\n" +" }\n" +" torch.prim.If.yield %4 : !torch.int\n" +" }\n" +" return %1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.zeros_like\"(%arg0: !torch.tuple, %arg1: !torch.optional, %arg2: !torch.optional, %arg3: !torch.optional, %arg4: !torch.optional, %arg5: !torch.optional) -> !torch.int {\n" +" %none = torch.constant.none\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = torch.aten.__is__ %arg1, %none : !torch.optional, !torch.none -> !torch.bool\n" +" %2 = torch.prim.If %1 -> (!torch.int) {\n" +" torch.prim.If.yield %0#1 : !torch.int\n" +" } else {\n" +" %3 = torch.prim.unchecked_cast %arg1 : !torch.optional -> !torch.int\n" +" torch.prim.If.yield %3 : !torch.int\n" +" }\n" +" return %2 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.ones_like\"(%arg0: !torch.tuple, %arg1: !torch.optional, %arg2: !torch.optional, %arg3: !torch.optional, %arg4: !torch.optional, %arg5: !torch.optional) -> !torch.int {\n" +" %none = torch.constant.none\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = torch.aten.__is__ %arg1, %none : !torch.optional, !torch.none -> !torch.bool\n" +" %2 = torch.prim.If %1 -> (!torch.int) {\n" +" torch.prim.If.yield %0#1 : !torch.int\n" +" } else {\n" +" %3 = torch.prim.unchecked_cast %arg1 : !torch.optional -> !torch.int\n" +" torch.prim.If.yield %3 : !torch.int\n" +" }\n" +" return %2 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.empty_like\"(%arg0: !torch.tuple, %arg1: !torch.optional, %arg2: !torch.optional, %arg3: !torch.optional, %arg4: !torch.optional, %arg5: !torch.optional) -> !torch.int {\n" +" %none = torch.constant.none\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = torch.aten.__is__ %arg1, %none : !torch.optional, !torch.none -> !torch.bool\n" +" %2 = torch.prim.If %1 -> (!torch.int) {\n" +" torch.prim.If.yield %0#1 : !torch.int\n" +" } else {\n" +" %3 = torch.prim.unchecked_cast %arg1 : !torch.optional -> !torch.int\n" +" torch.prim.If.yield %3 : !torch.int\n" +" }\n" +" return %2 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.full_like\"(%arg0: !torch.tuple, %arg1: !torch.union, %arg2: !torch.optional, %arg3: !torch.optional, %arg4: !torch.optional, %arg5: !torch.optional, %arg6: !torch.optional) -> !torch.int {\n" +" %none = torch.constant.none\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = torch.aten.__is__ %arg2, %none : !torch.optional, !torch.none -> !torch.bool\n" +" %2 = torch.prim.If %1 -> (!torch.int) {\n" +" torch.prim.If.yield %0#1 : !torch.int\n" +" } else {\n" +" %3 = torch.prim.unchecked_cast %arg2 : !torch.optional -> !torch.int\n" +" torch.prim.If.yield %3 : !torch.int\n" +" }\n" +" return %2 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.new_zeros\"(%arg0: !torch.tuple, %arg1: !torch.list, %arg2: !torch.optional, %arg3: !torch.optional, %arg4: !torch.optional, %arg5: !torch.optional) -> !torch.int {\n" +" %none = 
torch.constant.none\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = torch.aten.__is__ %arg2, %none : !torch.optional, !torch.none -> !torch.bool\n" +" %2 = torch.prim.If %1 -> (!torch.int) {\n" +" torch.prim.If.yield %0#1 : !torch.int\n" +" } else {\n" +" %3 = torch.prim.unchecked_cast %arg2 : !torch.optional -> !torch.int\n" +" torch.prim.If.yield %3 : !torch.int\n" +" }\n" +" return %2 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.new_ones\"(%arg0: !torch.tuple, %arg1: !torch.list, %arg2: !torch.optional, %arg3: !torch.optional, %arg4: !torch.optional, %arg5: !torch.optional) -> !torch.int {\n" +" %none = torch.constant.none\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = torch.aten.__is__ %arg2, %none : !torch.optional, !torch.none -> !torch.bool\n" +" %2 = torch.prim.If %1 -> (!torch.int) {\n" +" torch.prim.If.yield %0#1 : !torch.int\n" +" } else {\n" +" %3 = torch.prim.unchecked_cast %arg2 : !torch.optional -> !torch.int\n" +" torch.prim.If.yield %3 : !torch.int\n" +" }\n" +" return %2 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.new_empty\"(%arg0: !torch.tuple, %arg1: !torch.list, %arg2: !torch.optional, %arg3: !torch.optional, %arg4: !torch.optional, %arg5: !torch.optional) -> !torch.int {\n" +" %none = torch.constant.none\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = torch.aten.__is__ %arg2, %none : !torch.optional, !torch.none -> !torch.bool\n" +" %2 = torch.prim.If %1 -> (!torch.int) {\n" +" torch.prim.If.yield %0#1 : !torch.int\n" +" } else {\n" +" %3 = torch.prim.unchecked_cast %arg2 : !torch.optional -> !torch.int\n" +" torch.prim.If.yield %3 : !torch.int\n" +" }\n" +" return %2 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.new_empty_strided\"(%arg0: !torch.tuple, %arg1: !torch.list, %arg2: !torch.list, %arg3: !torch.optional, %arg4: !torch.optional, %arg5: !torch.optional, %arg6: !torch.optional) -> !torch.int {\n" +" %none = torch.constant.none\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = torch.aten.__is__ %arg3, %none : !torch.optional, !torch.none -> !torch.bool\n" +" %2 = torch.prim.If %1 -> (!torch.int) {\n" +" torch.prim.If.yield %0#1 : !torch.int\n" +" } else {\n" +" %3 = torch.prim.unchecked_cast %arg3 : !torch.optional -> !torch.int\n" +" torch.prim.If.yield %3 : !torch.int\n" +" }\n" +" return %2 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.rand_like\"(%arg0: !torch.tuple, %arg1: !torch.optional, %arg2: !torch.optional, %arg3: !torch.optional, %arg4: !torch.optional, %arg5: !torch.optional) -> !torch.int {\n" +" %none = torch.constant.none\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = torch.aten.__is__ %arg1, %none : !torch.optional, !torch.none -> !torch.bool\n" +" %2 = torch.prim.If %1 -> (!torch.int) {\n" +" torch.prim.If.yield %0#1 : !torch.int\n" +" } else {\n" +" %3 = torch.prim.unchecked_cast %arg1 : !torch.optional -> !torch.int\n" +" torch.prim.If.yield %3 : !torch.int\n" +" }\n" +" return %2 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.randn_like\"(%arg0: !torch.tuple, %arg1: !torch.optional, %arg2: !torch.optional, %arg3: !torch.optional, %arg4: !torch.optional, %arg5: !torch.optional) -> !torch.int {\n" +" %str = torch.constant.str \"AssertionError: \"\n" +" %none = torch.constant.none\n" +" %0:2 = 
torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = torch.aten.__is__ %arg1, %none : !torch.optional, !torch.none -> !torch.bool\n" +" %2 = torch.prim.If %1 -> (!torch.int) {\n" +" torch.prim.If.yield %0#1 : !torch.int\n" +" } else {\n" +" %5 = torch.prim.unchecked_cast %arg1 : !torch.optional -> !torch.int\n" +" torch.prim.If.yield %5 : !torch.int\n" +" }\n" +" %3 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_integer_dtype(%2) : (!torch.int) -> !torch.bool\n" +" %4 = torch.aten.__not__ %3 : !torch.bool -> !torch.bool\n" +" torch.prim.If %4 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" return %2 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten._to_copy\"(%arg0: !torch.tuple, %arg1: !torch.optional, %arg2: !torch.optional, %arg3: !torch.optional, %arg4: !torch.optional, %arg5: !torch.bool, %arg6: !torch.optional) -> !torch.int {\n" +" %none = torch.constant.none\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = torch.aten.__is__ %arg1, %none : !torch.optional, !torch.none -> !torch.bool\n" +" %2 = torch.prim.If %1 -> (!torch.int) {\n" +" torch.prim.If.yield %0#1 : !torch.int\n" +" } else {\n" +" %3 = torch.prim.unchecked_cast %arg1 : !torch.optional -> !torch.int\n" +" torch.prim.If.yield %3 : !torch.int\n" +" }\n" +" return %2 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.to.dtype\"(%arg0: !torch.tuple, %arg1: !torch.int, %arg2: !torch.bool, %arg3: !torch.bool, %arg4: !torch.optional) -> !torch.int {\n" +" return %arg1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.nvprims.convert_element_type\"(%arg0: !torch.tuple, %arg1: !torch.int) -> !torch.int {\n" +" return %arg1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.to.dtype_layout\"(%arg0: !torch.tuple, %arg1: !torch.optional, %arg2: !torch.optional, %arg3: !torch.optional, %arg4: !torch.optional, %arg5: !torch.bool, %arg6: !torch.bool, %arg7: !torch.optional) -> !torch.int {\n" +" %none = torch.constant.none\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = torch.aten.__is__ %arg1, %none : !torch.optional, !torch.none -> !torch.bool\n" +" %2 = torch.prim.If %1 -> (!torch.int) {\n" +" torch.prim.If.yield %0#1 : !torch.int\n" +" } else {\n" +" %3 = torch.prim.unchecked_cast %arg1 : !torch.optional -> !torch.int\n" +" torch.prim.If.yield %3 : !torch.int\n" +" }\n" +" return %2 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.to.device\"(%arg0: !torch.tuple, %arg1: !torch.Device, %arg2: !torch.int, %arg3: !torch.bool, %arg4: !torch.bool, %arg5: !torch.optional) -> !torch.int {\n" +" return %arg2 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.to.other\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.bool, %arg3: !torch.bool, %arg4: !torch.optional) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.type_as\"(%arg0: !torch.tuple, %arg1: !torch.tuple) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.randint.low\"(%arg0: !torch.int, %arg1: !torch.int, %arg2: !torch.list, %arg3: 
!torch.optional, %arg4: !torch.optional, %arg5: !torch.optional, %arg6: !torch.optional) -> !torch.int {\n" +" %str = torch.constant.str \"AssertionError: \"\n" +" %int4 = torch.constant.int 4\n" +" %none = torch.constant.none\n" +" %0 = torch.aten.__is__ %arg3, %none : !torch.optional, !torch.none -> !torch.bool\n" +" %1 = torch.prim.If %0 -> (!torch.int) {\n" +" torch.prim.If.yield %int4 : !torch.int\n" +" } else {\n" +" %2 = torch.prim.unchecked_cast %arg3 : !torch.optional -> !torch.int\n" +" %3 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%2) : (!torch.int) -> !torch.bool\n" +" %4 = torch.aten.__not__ %3 : !torch.bool -> !torch.bool\n" +" torch.prim.If %4 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" torch.prim.If.yield %2 : !torch.int\n" +" }\n" +" return %1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.randn\"(%arg0: !torch.list, %arg1: !torch.optional, %arg2: !torch.optional, %arg3: !torch.optional, %arg4: !torch.optional) -> !torch.int {\n" +" %str = torch.constant.str \"AssertionError: \"\n" +" %int6 = torch.constant.int 6\n" +" %none = torch.constant.none\n" +" %0 = torch.aten.__is__ %arg1, %none : !torch.optional, !torch.none -> !torch.bool\n" +" %1 = torch.prim.If %0 -> (!torch.int) {\n" +" torch.prim.If.yield %int6 : !torch.int\n" +" } else {\n" +" %2 = torch.prim.unchecked_cast %arg1 : !torch.optional -> !torch.int\n" +" %3 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_integer_dtype(%2) : (!torch.int) -> !torch.bool\n" +" %4 = torch.aten.__not__ %3 : !torch.bool -> !torch.bool\n" +" torch.prim.If %4 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" torch.prim.If.yield %2 : !torch.int\n" +" }\n" +" return %1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.randn.generator\"(%arg0: !torch.list, %arg1: !torch.any, %arg2: !torch.optional, %arg3: !torch.optional, %arg4: !torch.optional, %arg5: !torch.optional) -> !torch.int {\n" +" %str = torch.constant.str \"AssertionError: \"\n" +" %int6 = torch.constant.int 6\n" +" %none = torch.constant.none\n" +" %0 = torch.aten.__is__ %arg2, %none : !torch.optional, !torch.none -> !torch.bool\n" +" %1 = torch.prim.If %0 -> (!torch.int) {\n" +" torch.prim.If.yield %int6 : !torch.int\n" +" } else {\n" +" %2 = torch.prim.unchecked_cast %arg2 : !torch.optional -> !torch.int\n" +" %3 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_integer_dtype(%2) : (!torch.int) -> !torch.bool\n" +" %4 = torch.aten.__not__ %3 : !torch.bool -> !torch.bool\n" +" torch.prim.If %4 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" torch.prim.If.yield %2 : !torch.int\n" +" }\n" +" return %1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.var_mean.correction\"(%arg0: !torch.tuple, %arg1: !torch.optional>, %arg2: !torch.union, %arg3: !torch.bool) -> !torch.tuple {\n" +" %int7 = torch.constant.int 7\n" +" %int10 = torch.constant.int 10\n" +" %int6 = torch.constant.int 6\n" +" %int9 = torch.constant.int 9\n" +" %none = torch.constant.none\n" +" %str = torch.constant.str \"AssertionError: \"\n" +" %0:2 = 
torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_integer_dtype(%0#1) : (!torch.int) -> !torch.bool\n" +" %2 = torch.aten.__not__ %1 : !torch.bool -> !torch.bool\n" +" torch.prim.If %2 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %3 = torch.aten.eq.int %0#1, %int9 : !torch.int, !torch.int -> !torch.bool\n" +" %4 = torch.prim.If %3 -> (!torch.tuple) {\n" +" %5 = torch.prim.TupleConstruct %int6, %0#1 : !torch.int, !torch.int -> !torch.tuple\n" +" torch.prim.If.yield %5 : !torch.tuple\n" +" } else {\n" +" %5 = torch.aten.eq.int %0#1, %int10 : !torch.int, !torch.int -> !torch.bool\n" +" %6 = torch.prim.If %5 -> (!torch.tuple) {\n" +" %7 = torch.prim.TupleConstruct %int7, %0#1 : !torch.int, !torch.int -> !torch.tuple\n" +" torch.prim.If.yield %7 : !torch.tuple\n" +" } else {\n" +" %7 = torch.prim.TupleConstruct %0#1, %0#1 : !torch.int, !torch.int -> !torch.tuple\n" +" torch.prim.If.yield %7 : !torch.tuple\n" +" }\n" +" torch.prim.If.yield %6 : !torch.tuple\n" +" }\n" +" return %4 : !torch.tuple\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.var_mean\"(%arg0: !torch.tuple, %arg1: !torch.bool) -> !torch.tuple {\n" +" %int7 = torch.constant.int 7\n" +" %int10 = torch.constant.int 10\n" +" %int6 = torch.constant.int 6\n" +" %int9 = torch.constant.int 9\n" +" %none = torch.constant.none\n" +" %str = torch.constant.str \"AssertionError: \"\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_integer_dtype(%0#1) : (!torch.int) -> !torch.bool\n" +" %2 = torch.aten.__not__ %1 : !torch.bool -> !torch.bool\n" +" torch.prim.If %2 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %3 = torch.aten.eq.int %0#1, %int9 : !torch.int, !torch.int -> !torch.bool\n" +" %4 = torch.prim.If %3 -> (!torch.tuple) {\n" +" %5 = torch.prim.TupleConstruct %int6, %0#1 : !torch.int, !torch.int -> !torch.tuple\n" +" torch.prim.If.yield %5 : !torch.tuple\n" +" } else {\n" +" %5 = torch.aten.eq.int %0#1, %int10 : !torch.int, !torch.int -> !torch.bool\n" +" %6 = torch.prim.If %5 -> (!torch.tuple) {\n" +" %7 = torch.prim.TupleConstruct %int7, %0#1 : !torch.int, !torch.int -> !torch.tuple\n" +" torch.prim.If.yield %7 : !torch.tuple\n" +" } else {\n" +" %7 = torch.prim.TupleConstruct %0#1, %0#1 : !torch.int, !torch.int -> !torch.tuple\n" +" torch.prim.If.yield %7 : !torch.tuple\n" +" }\n" +" torch.prim.If.yield %6 : !torch.tuple\n" +" }\n" +" return %4 : !torch.tuple\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.atan2\"(%arg0: !torch.tuple, %arg1: !torch.tuple) -> !torch.int {\n" +" %int6 = torch.constant.int 6\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" +" %2 = torch.prim.ListConstruct %0#0, %1#0 : (!torch.int, !torch.int) -> !torch.list>\n" +" %3 = torch.prim.ListConstruct %0#1, %1#1 : (!torch.int, !torch.int) -> !torch.list\n" +" %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list>, !torch.list) -> !torch.int\n" +" %5 = call 
@__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_integer_dtype(%4) : (!torch.int) -> !torch.bool\n" +" %6 = torch.prim.If %5 -> (!torch.int) {\n" +" torch.prim.If.yield %int6 : !torch.int\n" +" } else {\n" +" torch.prim.If.yield %4 : !torch.int\n" +" }\n" +" return %6 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.atan\"(%arg0: !torch.tuple) -> !torch.int {\n" +" %int6 = torch.constant.int 6\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_integer_dtype(%0#1) : (!torch.int) -> !torch.bool\n" +" %2 = torch.prim.If %1 -> (!torch.int) {\n" +" torch.prim.If.yield %int6 : !torch.int\n" +" } else {\n" +" torch.prim.If.yield %0#1 : !torch.int\n" +" }\n" +" return %2 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.linear\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.optional>) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.cat\"(%arg0: !torch.list>, %arg1: !torch.int) -> !torch.int {\n" +" %true = torch.constant.bool true\n" +" %none = torch.constant.none\n" +" %str = torch.constant.str \"AssertionError: \"\n" +" %int0 = torch.constant.int 0\n" +" %0 = torch.prim.ListConstruct : () -> !torch.list>\n" +" %1 = torch.prim.ListConstruct : () -> !torch.list\n" +" %2 = torch.aten.len.t %arg0 : !torch.list> -> !torch.int\n" +" %3 = torch.aten.ne.int %2, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %3 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %4 = torch.aten.len.t %arg0 : !torch.list> -> !torch.int\n" +" torch.prim.Loop %4, %true, init() {\n" +" ^bb0(%arg2: !torch.int):\n" +" %6 = torch.aten.__getitem__.t %arg0, %arg2 : !torch.list>, !torch.int -> !torch.tuple\n" +" %7:2 = torch.prim.TupleUnpack %6 : !torch.tuple -> !torch.int, !torch.int\n" +" %8 = torch.aten.append.t %0, %7#0 : !torch.list>, !torch.int -> !torch.list>\n" +" %9 = torch.aten.append.t %1, %7#1 : !torch.list, !torch.int -> !torch.list\n" +" torch.prim.Loop.condition %true, iter()\n" +" } : (!torch.int, !torch.bool) -> ()\n" +" %5 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%0, %1) : (!torch.list>, !torch.list) -> !torch.int\n" +" return %5 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten._shape_as_tensor\"(%arg0: !torch.tuple) -> !torch.int {\n" +" %int4 = torch.constant.int 4\n" +" return %int4 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.ScalarImplicit\"(%arg0: !torch.tuple) -> !torch.int {\n" +" %str = torch.constant.str \"AssertionError: Unexpected dtype!\"\n" +" %int4 = torch.constant.int 4\n" +" %false = torch.constant.bool false\n" +" %int11 = torch.constant.int 11\n" +" %int7 = torch.constant.int 7\n" +" %none = torch.constant.none\n" +" %str_0 = torch.constant.str \"AssertionError: \"\n" +" %0 = torch.prim.Uninitialized : !torch.int\n" +" %1:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%1#1) : (!torch.int) -> !torch.bool\n" +" %3 = torch.aten.__not__ %2 : !torch.bool -> !torch.bool\n" +" torch.prim.If 
%3 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_float_dtype(%1#1) : (!torch.int) -> !torch.bool\n" +" %5 = torch.prim.If %4 -> (!torch.int) {\n" +" torch.prim.If.yield %int7 : !torch.int\n" +" } else {\n" +" %6 = func.call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_integer_dtype(%1#1) : (!torch.int) -> !torch.bool\n" +" %7 = torch.prim.If %6 -> (!torch.bool) {\n" +" %9 = torch.aten.ne.int %1#1, %int11 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %9 : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %false : !torch.bool\n" +" }\n" +" %8 = torch.prim.If %7 -> (!torch.int) {\n" +" torch.prim.If.yield %int4 : !torch.int\n" +" } else {\n" +" %9 = torch.aten.eq.int %1#1, %int11 : !torch.int, !torch.int -> !torch.bool\n" +" %10 = torch.prim.If %9 -> (!torch.int) {\n" +" torch.prim.If.yield %int11 : !torch.int\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield %0 : !torch.int\n" +" }\n" +" torch.prim.If.yield %10 : !torch.int\n" +" }\n" +" torch.prim.If.yield %8 : !torch.int\n" +" }\n" +" return %5 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.prim.NumToTensor.Scalar\"(%arg0: !torch.union) -> !torch.int {\n" +" %0 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg0) : (!torch.union) -> !torch.int\n" +" return %0 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.softmax.int\"(%arg0: !torch.tuple, %arg1: !torch.int, %arg2: !torch.optional) -> !torch.int {\n" +" %none = torch.constant.none\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = torch.aten.__is__ %arg2, %none : !torch.optional, !torch.none -> !torch.bool\n" +" %2 = torch.prim.If %1 -> (!torch.int) {\n" +" torch.prim.If.yield %0#1 : !torch.int\n" +" } else {\n" +" %3 = torch.prim.unchecked_cast %arg2 : !torch.optional -> !torch.int\n" +" torch.prim.If.yield %3 : !torch.int\n" +" }\n" +" return %2 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten._softmax\"(%arg0: !torch.tuple, %arg1: !torch.int, %arg2: !torch.bool) -> !torch.int {\n" +" %int6 = torch.constant.int 6\n" +" %none = torch.constant.none\n" +" %str = torch.constant.str \"AssertionError: \"\n" +" %int5 = torch.constant.int 5\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = torch.prim.If %arg2 -> (!torch.int) {\n" +" %2 = torch.aten.eq.int %0#1, %int5 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %2 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" torch.prim.If.yield %int6 : !torch.int\n" +" } else {\n" +" torch.prim.If.yield %0#1 : !torch.int\n" +" }\n" +" return %1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten._log_softmax\"(%arg0: !torch.tuple, %arg1: !torch.int, %arg2: !torch.bool) -> !torch.int {\n" +" %int6 = torch.constant.int 6\n" +" %none = torch.constant.none\n" +" %str = torch.constant.str \"AssertionError: \"\n" +" %int5 = torch.constant.int 5\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = torch.prim.If %arg2 -> (!torch.int) {\n" +" %2 = 
torch.aten.eq.int %0#1, %int5 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %2 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" torch.prim.If.yield %int6 : !torch.int\n" +" } else {\n" +" torch.prim.If.yield %0#1 : !torch.int\n" +" }\n" +" return %1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.log_softmax.int\"(%arg0: !torch.tuple, %arg1: !torch.int, %arg2: !torch.optional) -> !torch.int {\n" +" %none = torch.constant.none\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = torch.aten.__is__ %arg2, %none : !torch.optional, !torch.none -> !torch.bool\n" +" %2 = torch.prim.If %1 -> (!torch.int) {\n" +" torch.prim.If.yield %0#1 : !torch.int\n" +" } else {\n" +" %3 = torch.prim.unchecked_cast %arg2 : !torch.optional -> !torch.int\n" +" torch.prim.If.yield %3 : !torch.int\n" +" }\n" +" return %2 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.embedding\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.int, %arg3: !torch.bool, %arg4: !torch.bool) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten._embedding_bag\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.tuple, %arg3: !torch.bool, %arg4: !torch.int, %arg5: !torch.bool, %arg6: !torch.optional>, %arg7: !torch.bool, %arg8: !torch.int) -> !torch.tuple {\n" +" %int4 = torch.constant.int 4\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = torch.prim.TupleConstruct %0#1, %int4, %int4, %int4 : !torch.int, !torch.int, !torch.int, !torch.int -> !torch.tuple\n" +" return %1 : !torch.tuple\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.embedding_bag.padding_idx\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.tuple, %arg3: !torch.bool, %arg4: !torch.int, %arg5: !torch.bool, %arg6: !torch.optional>, %arg7: !torch.bool, %arg8: !torch.optional) -> !torch.tuple {\n" +" %int4 = torch.constant.int 4\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = torch.prim.TupleConstruct %0#1, %int4, %int4, %int4 : !torch.int, !torch.int, !torch.int, !torch.int -> !torch.tuple\n" +" return %1 : !torch.tuple\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.bucketize.Tensor\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.bool, %arg3: !torch.bool) -> !torch.int {\n" +" %int4 = torch.constant.int 4\n" +" %int3 = torch.constant.int 3\n" +" %0 = torch.prim.If %arg2 -> (!torch.int) {\n" +" torch.prim.If.yield %int3 : !torch.int\n" +" } else {\n" +" torch.prim.If.yield %int4 : !torch.int\n" +" }\n" +" return %0 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.prims.squeeze\"(%arg0: !torch.tuple, %arg1: !torch.list) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" "}\n" ""; // clang-format on diff --git a/lib/Dialect/Torch/Transforms/CMakeLists.txt b/lib/Dialect/Torch/Transforms/CMakeLists.txt index ce577cf5b..9c8bcda94 100644 --- a/lib/Dialect/Torch/Transforms/CMakeLists.txt +++ b/lib/Dialect/Torch/Transforms/CMakeLists.txt @@ -12,7 +12,6 @@ add_mlir_library(TorchMLIRTorchPasses RecomposeComplexOps.cpp ReduceOpVariants.cpp RefinePublicReturn.cpp - RefineTypes.cpp ReifyShapeCalculations.cpp 
  ReifyDtypeCalculations.cpp
  ReifyAbstractInterpCalculationsUtils.cpp
diff --git a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp
index ac077ca2f..f7cf3c95a 100644
--- a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp
+++ b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp
@@ -103,7 +103,8 @@ static LogicalResult checkType(Operation *op, Type type,
         ->emitError(
             "unsupported by backend contract: tensor with unknown dtype")
         .attachNote()
-        .append("this is likely due to a missing case in RefineTypes");
+        .append("this is likely due to a missing transfer function in "
+                "abstract_interp_lib_gen.py");
   } else {
     return failure();
   }
diff --git a/lib/Dialect/Torch/Transforms/Passes.cpp b/lib/Dialect/Torch/Transforms/Passes.cpp
index 5ed5d53bd..407e90247 100644
--- a/lib/Dialect/Torch/Transforms/Passes.cpp
+++ b/lib/Dialect/Torch/Transforms/Passes.cpp
@@ -119,20 +119,14 @@ void mlir::torch::Torch::createTorchSimplificationPipeline(
   // Update the return op to return value tensors.
   pm.addPass(Torch::createRefinePublicReturnPass());
   pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
-  // Do shape refinement.
-  // This should be run before RefineTypes (which primarily does dtype
-  // inference), because Torch type promotion rules actually depend on the shape
-  // of the operand.
+  // Do shape and dtype refinement.
+  // Shape refinement should be run before dtype refinement because Torch type
+  // promotion rules actually depend on the shape of the operand.
   createTorchShapeRefinementPipeline(pm, options);
   createTorchDtypeRefinementPipeline(pm, options);
-  // Refine types in the program, which mainly means inferring dtypes of ops.
-  pm.addNestedPass<func::FuncOp>(Torch::createRefineTypesPass());
   // Propagate to ABI return types the shape/dtype information discovered by
   // the previous pass. Doing this is ABI-compatible for our backends.
   pm.addPass(Torch::createRefinePublicReturnPass());
-  // This can fold away some branches given the information obtained from
-  // RefineTypes before doing maximize value semantics, which only works with
-  // basic blocks.
   pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
   if (options.decompose) {
     pm.addNestedPass<func::FuncOp>(
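The ordering constraint in the comment above comes from PyTorch's type-promotion
semantics: zero-rank tensors promote like scalars, so an op's result dtype can
depend on operand rank, which only shape refinement can establish. A small
illustration in plain PyTorch (illustrative only, not part of this patch):

    import torch

    a = torch.ones(2, dtype=torch.float32)
    b = torch.tensor(1.0, dtype=torch.float64)  # rank-0: promotes like a scalar
    c = torch.ones(2, dtype=torch.float64)      # rank-1: promotes as a tensor
    print((a + b).dtype)  # torch.float32: the rank-0 operand does not win
    print((a + c).dtype)  # torch.float64: the same-rank operand does

This is also why the dtype functions in the generated library receive
(rank, dtype) tuples rather than bare dtypes.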
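The new diagnostic above points at abstract_interp_lib_gen.py because each
`__torch_mlir_dtype_fn.*` entry in the generated library is compiled from a
Python transfer function built on the `library_generator` helpers visible in
the MLIR (`promote_dtypes`, `get_dtype_of_scalar`, `is_integer_dtype`, and
friends). As a rough sketch, the `aten.where.self` entry corresponds to Python
along these lines; the identifier spelling below is hypothetical (the real
generator encodes the op name with a special naming convention):

    from typing import List, Optional, Tuple

    from torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator import (
        promote_dtypes,
    )

    # Hypothetical spelling of the transfer function for aten.where.self; the
    # body mirrors the generated MLIR above: unpack the (rank, dtype) pairs of
    # the two value operands and apply standard type promotion to them.
    def aten_where_self_dtype(condition_rank_dtype: Tuple[int, int],
                              self_rank_dtype: Tuple[int, int],
                              other_rank_dtype: Tuple[int, int]) -> int:
        self_rank, self_dtype = self_rank_dtype
        other_rank, other_dtype = other_rank_dtype
        ranks: List[Optional[int]] = [self_rank, other_rank]
        dtypes: List[int] = [self_dtype, other_dtype]
        return promote_dtypes(ranks, dtypes)

The integer dtype codes threaded through these functions are the values of
torch_upstream::ScalarType, which is why constants such as %int4 (int64),
%int6 (float32), %int7 (float64), and %int11 (bool) recur throughout the
generated code.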
diff --git a/lib/Dialect/Torch/Transforms/RefineTypes.cpp b/lib/Dialect/Torch/Transforms/RefineTypes.cpp
deleted file mode 100644
index f95ea6e7f..000000000
--- a/lib/Dialect/Torch/Transforms/RefineTypes.cpp
+++ /dev/null
@@ -1,1770 +0,0 @@
-//===- RefineTypes.cpp ------------------------*- C++-*-===//
-//
-// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-// Also available under a BSD-style license. See LICENSE.
-//
-//===----------------------------------------------------------------------===//
-//
-// This file implements a dataflow analysis primarily used to infer dtypes
-// of tensors in the program. Shapes are handled separately with a
-// more involved mechanism (see createTorchShapeRefinementPipeline).
-//
-// The analysis performed in this file is implemented with MLIR's dataflow
-// analysis framework, which was originally developed for SCCP, and so is an
-// optimistic framework. It proceeds by assuming that all Value's have a
-// maximally optimistic ("bottom") lattice element associated with them, and
-// then the `visitOperation` method (and some built-in handling for control
-// flow) gradually relaxes that optimism until the lattice elements associated
-// with each Value either settle to an (optimistic) fixed-point, or need to fall
-// back on a suitable pessimistic lattice element.
-//
-// A note on dataflow analysis terminology:
-// In dataflow analysis (or other contexts where lattices appear), it is
-// frequently confusing because meet/join and related aspects of lattices
-// (such as what is "up"/"down" or "top"/"bottom" in the lattice) are dual to
-// each other and so a convention has to be chosen to ground the terminology.
-//
-// In the context of this dataflow analysis, we use the terms with the following
-// senses (many examples are given to build intuition):
-// - "top" means the state of least specific knowledge (i.e. most pessimistic
-//   possible knowledge)
-// - "bottom" is the lattice element with such specific knowledge that "join"ing
-//   with it is an identity operation. (i.e. most optimistic possible knowledge)
-// - "moving down the lattice" means moving towards having more specific
-//   knowledge
-// - "moving up the lattice" means moving towards having less specific knowledge
-// - "meet" means
-//   - "move down the lattice" (greatest lower bound)
-//   - "constrict"
-//   - "refine"
-//   - "assume union of information from both lattice elements"
-// - "join" means
-//   - "move up the lattice" (least upper bound)
-//   - "widen"
-//   - "relax"
-//   - "assume intersection of information from both lattice elements"
-//
-// Note: This pass is kept completely separate from
-// createShapeRefinementPipeline because any interaction between the two would
-// usually require a fixed-point iteration to work in generality.
-//
-//===----------------------------------------------------------------------===//
-
-#include "PassDetail.h"
-
-#include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h"
-#include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
-#include "mlir/Analysis/DataFlow/SparseAnalysis.h"
-#include "mlir/IR/Builders.h"
-#include "mlir/IR/BuiltinDialect.h"
-#include "mlir/IR/BuiltinOps.h"
-#include "mlir/IR/Matchers.h"
-#include "mlir/Transforms/DialectConversion.h"
-#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
-#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
-#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
-#include "torch-mlir/Dialect/Torch/Transforms/Passes.h"
-#include "torch-mlir/Dialect/Torch/Utils/TorchUpstream.h"
-#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
-
-using namespace mlir;
-using namespace mlir::torch;
-using namespace mlir::torch::Torch;
-
-// -----------------------------------------------------------------------------
-// Analysis.
-// -----------------------------------------------------------------------------
-
-static Type getTypeForDTypeInteger(MLIRContext *context, int64_t dtypeInt) {
-  FailureOr<Type> result =
-      getTypeForScalarType(context, (torch_upstream::ScalarType)dtypeInt);
-  return failed(result) ? Type() : *result;
-}
-
-static Type getDtypeOrDefault(MLIRContext *context, Value optionalDtype,
-                              Type defaultDtype) {
-  int64_t dtypeInt;
-  if (matchPattern(optionalDtype, m_TorchConstantInt(&dtypeInt)))
-    return getTypeForDTypeInteger(context, dtypeInt);
-  else if (optionalDtype.getType().isa<Torch::NoneType>())
-    return defaultDtype;
-  return Type();
-}
-
-// Get the kind enum for `ValueKnowledge.kind`.
-static torch_upstream::TypeKind getTypeKind(Type type) {
-  if (type.isa<NumberType>())
-    return torch_upstream::TypeKind::NumberType;
-  if (type.isa<IntType>())
-    return torch_upstream::TypeKind::IntType;
-  if (type.isa<Torch::FloatType>())
-    return torch_upstream::TypeKind::FloatType;
-  if (type.isa<BaseTensorType>())
-    return torch_upstream::TypeKind::TensorType;
-  if (type.isa<Torch::NoneType>())
-    return torch_upstream::TypeKind::NoneType;
-  // Skip the Torch::OptionalType on purpose because optional knowledge is
-  // tracked separately. See comments for `ValueKnowledge.kind` field.
-  return torch_upstream::TypeKind::AnyType;
-}
-
-
-enum class OptionalKnowledge {
-  unKnown,
-  isNone,
-  notNone,
-};
-
-/// Returns the OptionalKnowledge that assumes information from both `lhs` and
-/// `rhs`. Returns `std::nullopt` if the knowledges are contradictory.
-static std::optional<OptionalKnowledge>
-meetOptionalKnowledge(OptionalKnowledge lhs, OptionalKnowledge rhs) {
-  if (lhs == OptionalKnowledge::unKnown)
-    return rhs;
-  if (rhs == OptionalKnowledge::unKnown)
-    return lhs;
-  if (lhs == rhs)
-    return lhs;
-  return std::nullopt;
-}
-
-static OptionalKnowledge joinOptionalKnowledge(OptionalKnowledge lhs,
-                                               OptionalKnowledge rhs) {
-  if (lhs == rhs)
-    return lhs;
-  return OptionalKnowledge::unKnown;
-}
-
-namespace {
-// Statically known information for a particular Value.
-//
-// This struct currently tracks information relevant for tensor/array-like
-// shaped types as well as whether an object is None or not, namely
-// !torch.optional. It is fine to associate a `ValueKnowledge` with a non-shaped
-// type or non OptionalType as long as it is in the default "no knowledge"
-// state returned by `getPessimisticValueState`. The important invariant is that
-// we cannot claim to know something about a value which is false.
-// This class could also be called "dataflow facts", "lattice value", etc.
-struct ValueKnowledge {
-  ValueKnowledge() = default;
-  ValueKnowledge(Type dtype, Type scalarType,
-                 OptionalKnowledge optionalKnowledge,
-                 torch_upstream::TypeKind kind)
-      : isInitialized(true), dtype(dtype), scalarType(scalarType), kind(kind),
-        optional(optionalKnowledge) {}
-
-  void print(raw_ostream &os) const {
-    os << "ValueKnowledge(";
-    if (!isInitialized) {
-      os << "uninitialized)";
-      return;
-    }
-    if (dtype)
-      os << "dtype=" << dtype;
-    if (scalarType)
-      os << ", scalarType=" << scalarType;
-    if (optional != OptionalKnowledge::unKnown)
-      os << ", optional=" << (int)optional;
-    os << ", kind=" << (int)kind << ")";
-  }
-  void setScalarType(Type type) {
-    bool isValidScalarType = type.isa<NumberType, IntType, Torch::FloatType>();
-    (void)isValidScalarType;
-    assert(isValidScalarType &&
-           "scalarType can only be one of NumberType, IntType and FloatType");
-    scalarType = type;
-    kind = getTypeKind(type);
-  }
-
-  // Get the static knowledge intrinsic to `type`.
- static ValueKnowledge getKnowledgeFromType(Type type) {
- ValueKnowledge result = getPessimisticValueState(type.getContext());
- result.kind = getTypeKind(type);
- switch (result.kind) {
- case torch_upstream::TypeKind::TensorType:
- result.dtype = type.cast().getOptionalDtype();
- result.optional = OptionalKnowledge::notNone;
- return result;
- case torch_upstream::TypeKind::NumberType:
- case torch_upstream::TypeKind::IntType:
- case torch_upstream::TypeKind::FloatType:
- result.scalarType = type;
- result.optional = OptionalKnowledge::notNone;
- return result;
- case torch_upstream::TypeKind::NoneType:
- result.optional = OptionalKnowledge::isNone;
- return result;
- default:
- if (type.isa())
- return result;
- // All other types that are not optional type.
- result.optional = OptionalKnowledge::notNone;
- return result;
- }
- }
-
- // Return a pessimistic/conservative value state without assuming any knowledge
- // about the IR.
- static ValueKnowledge getPessimisticValueState(MLIRContext *context) {
- return ValueKnowledge(Type(), Type(), OptionalKnowledge::unKnown,
- torch_upstream::TypeKind::AnyType);
- }
- // Return a pessimistic/conservative value state only using knowledge already
- // recorded in the IR.
- static ValueKnowledge getPessimisticValueState(Value value) {
- return getKnowledgeFromType(value.getType());
- }
-
- static ValueKnowledge getNotNonePessimisticValueState(MLIRContext *context) {
- return ValueKnowledge(Type(), Type(), OptionalKnowledge::notNone,
- torch_upstream::TypeKind::AnyType);
- }
-
- static ValueKnowledge getTensorPessimisticValueState(MLIRContext *context) {
- return ValueKnowledge(Type(), Type(), OptionalKnowledge::notNone,
- torch_upstream::TypeKind::TensorType);
- }
-
- static ValueKnowledge getScalarPessimisticValueState(MLIRContext *context) {
- return ValueKnowledge(Type(), NumberType::get(context),
- OptionalKnowledge::notNone,
- torch_upstream::TypeKind::NumberType);
- }
-
- bool operator==(const ValueKnowledge &rhs) const {
- if (!isInitialized && !rhs.isInitialized)
- return true;
- return isInitialized && rhs.isInitialized &&
- std::make_tuple(dtype, optional) ==
- std::make_tuple(rhs.dtype, rhs.optional);
- }
-
- // Return true if the `refinedType` has more concrete type info than `type`.
- static bool hasStrictlyMoreRefinedTypeInfo(const ValueKnowledge &refinedType,
- const ValueKnowledge &type) {
- if (!refinedType.isInitialized)
- return false;
- if (!type.isInitialized)
- return true;
-
- if (type.kind == torch_upstream::TypeKind::AnyType &&
- refinedType.kind != torch_upstream::TypeKind::AnyType)
- return true;
-
- // If both are tensors but `type` doesn't have concrete dtype info.
- if (refinedType.kind == torch_upstream::TypeKind::TensorType &&
- type.kind == torch_upstream::TypeKind::TensorType) {
- return refinedType.dtype && !type.dtype;
- }
-
- if (refinedType.scalarType && type.scalarType)
- return isValidSubtype(refinedType.scalarType, type.scalarType);
-
- return false;
- }
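[Editorial aside, not part of the patch: for intuition, the dtype component of the `join`/`meet` operations below behaves like this minimal Python sketch, where `None` stands in for the unknown-dtype state.]

    def join_dtype(lhs, rhs):
        """Control-flow merge: keep a dtype fact only if both edges agree."""
        return lhs if lhs == rhs else None

    def meet_dtype(lhs, rhs):
        """Refinement: combine facts; contradictory facts have no meet."""
        if lhs is None:
            return rhs
        if rhs is None:
            return lhs
        if lhs == rhs:
            return lhs
        raise ValueError("contradictory dtype facts")

    assert join_dtype("f32", None) is None   # pessimism wins at merges
    assert meet_dtype("f32", None) == "f32"  # refinement keeps the known fact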
-
- // Given two pieces of static knowledge, intersect the facts that are known in
- // both knowledges. This always produces knowledge that has fewer (or equal)
- // facts than both the lhs and rhs.
- //
- // This operator is used, for example, at control flow join points: if
- // predecessors A and B forward a block argument to a common successor C, then
- // we need to calculate what can be known for sure about the block argument if
- // the control flow is coming from either A or B. So we can't assume facts
- // just because they are true on one control flow edge. They must be true on
- // both.
- static ValueKnowledge join(const ValueKnowledge &lhs,
- const ValueKnowledge &rhs) {
- if (!lhs.isInitialized)
- return rhs;
- if (!rhs.isInitialized)
- return lhs;
-
- // Mental model: All conditions are checking how to change from the safe "no
- // knowledge" default-initialized state to a state with more knowledge
- // consistent with lhs and rhs.
- ValueKnowledge result = joinTypes(lhs, rhs);
- result.optional = joinOptionalKnowledge(lhs.optional, rhs.optional);
- return result;
- }
-
- static ValueKnowledge joinTypes(const ValueKnowledge &lhs,
- const ValueKnowledge &rhs) {
- if (!lhs.isInitialized)
- return rhs;
- if (!rhs.isInitialized)
- return lhs;
-
- if (hasStrictlyMoreRefinedTypeInfo(lhs, rhs))
- return rhs;
- if (hasStrictlyMoreRefinedTypeInfo(rhs, lhs))
- return lhs;
- if (lhs == rhs)
- return lhs;
- return getPessimisticValueState(nullptr);
- }
-
- // Given two pieces of static knowledge, calculate new knowledge that assumes
- // the facts from both.
- // If the two pieces of knowledge are contradictory, std::nullopt is returned.
- static std::optional meet(const ValueKnowledge &lhs,
- const ValueKnowledge &rhs) {
- if (!lhs.isInitialized)
- return lhs;
- if (!rhs.isInitialized)
- return rhs;
-
- std::optional knowledge = meetTypes(lhs, rhs);
-
- if (!knowledge.has_value())
- return std::nullopt;
- ValueKnowledge result = knowledge.value();
-
- std::optional optional =
- meetOptionalKnowledge(lhs.optional, rhs.optional);
- if (!optional.has_value())
- return std::nullopt;
- result.optional = optional.value();
- return result;
- }
-
- static std::optional meetTypes(const ValueKnowledge &lhs,
- const ValueKnowledge &rhs) {
- if (!lhs.isInitialized)
- return lhs;
- if (!rhs.isInitialized)
- return rhs;
-
- if (hasStrictlyMoreRefinedTypeInfo(lhs, rhs))
- return lhs;
- if (hasStrictlyMoreRefinedTypeInfo(rhs, lhs))
- return rhs;
- if (lhs == rhs)
- return lhs;
- return std::nullopt;
- }
-
- // We start in the uninitialized state by default.
- bool isInitialized = false;
-
- // The dtype of a tensor.
- // This is equal to nullptr for the following cases:
- // 1. it is unknown whether the value is a tensor or not, i.e. the `kind` field
- // is torch_upstream::TypeKind::AnyType.
- // 2. the value is a tensor type but the dtype is unknown.
- // 3. the value is not a tensor type.
- Type dtype;
-
- // The type of a scalar.
- // This is equal to nullptr for the following cases:
- // 1. it is unknown whether the value is a scalar or not, i.e. the `kind` field
- // is torch_upstream::TypeKind::AnyType.
- // 2. the value is not a scalar type.
- Type scalarType;
-
- // The type kind. If it's torch_upstream::TypeKind::AnyType,
- // all the type fields are nullptr. Note that the `kind` never equals
- // torch_upstream::TypeKind::OptionalType because optional knowledge is
- // tracked separately through the `optional` field.
- torch_upstream::TypeKind kind;
-
- // What is known about an optional value.
- // When equal to OptionalKnowledge::notNone, the type info is kept in type
- // fields like `dtype`, `scalarType`.
- // When equal to OptionalKnowledge::isNone or OptionalKnowledge::unKnown, the
- // other type fields are currently nullptr. It might be worth considering
- // tracking wrapped type info when OptionalKnowledge::unKnown in the future.
- OptionalKnowledge optional;
-};
-} // namespace
-
-using ValueState = dataflow::Lattice;
-// Register TypeID for the dataflow framework.
-MLIR_DECLARE_EXPLICIT_TYPE_ID(ValueState) -MLIR_DEFINE_EXPLICIT_TYPE_ID(ValueState) - -namespace { -// Forward intraprocedural dataflow for type information. -class TypeAnalysis : public dataflow::SparseDataFlowAnalysis< - dataflow::Lattice> { -public: - using BaseT = - dataflow::SparseDataFlowAnalysis>; - using BaseT::SparseDataFlowAnalysis; - - // Compute the knowledge for the results of an op, based on the knowledge of - // the operands and any information intrinsic to `op`. - void visitOperation(Operation *op, ArrayRef operands, - ArrayRef results) final; - - void setToEntryState(ValueState *lattice) override { - auto refType = lattice->getPoint().getType(); - auto knowledge = ValueKnowledge::getKnowledgeFromType(refType); - propagateIfChanged(lattice, lattice->join(knowledge)); - } - -private: - // Get the MLIR type of the tensor dtype given the dtype integer value and the - // input dtype. When DType is None the type is inferred from the input dtype. - void fillInDTypeGivenDTypeIntAndInputDType(ValueKnowledge &knowledge, - Value dtype, Type inputDType); - - // Get the MLIR type of the tensor dtype given the dtype integer value and - // data type of torch type. When DType is None the type is inferred from the - // data type. - void fillInDTypeGivenDTypeAndDataType(ValueKnowledge &knowledge, Value dtype, - Type dataType); - - /// Incorporates `knowledge` into the lattice state of `v`. - /// - /// This method should be used instead of - /// `getLatticeElement(v).join(knowledge)`, because this method knows how to - /// correctly handle the case of existing static knowledge from the type - /// of `v`. - void incorporateKnowledge(Value v, const ValueKnowledge &knowledge); - - void visitAtenLinearOp(AtenLinearOp op, - ArrayRef operands); - void visitAtenArangeStartStepOp(AtenArangeStartStepOp op); - void visitAtenArangeStartOp(AtenArangeStartOp op); - void visitAtenArangeOp(AtenArangeOp op); - void visitAtenArangeLikeOpHelper(Operation *op, std::optional start, - Value end, std::optional step, - Value dtype); - void visitReductionAlongAllDimsOp(Operation *op, Type dtype, - ArrayRef operands); - void visitReductionAlongDimIntListOp(Operation *op, Value dim, Value keepdim, - Type dtype, - ArrayRef operands); - void visitReductionAlongDimIntOp(Operation *op, Value dim, Value keepdim, - Type dtype, - ArrayRef operands, - int resNum = 0); - template void visitScalarToTensorConversionOp(OpTy op); - void visitAtenTensorOp(AtenTensorOp op); - template - void visitConstantTensorAllocOp(OpTy op, std::optional dataType); - template - void visitConstantTensorAllocLikeOp(OpTy op, - ArrayRef operands); - template - void visitConstantTensorNewLikeOp(OpTy op, - ArrayRef operands); - template - void visitAtenToDtypeLikeOp(OpTy op, ArrayRef operands); - template - void visitTypeConversionOp(OpTy op, ArrayRef operands); - template - void visitAtenCatLikeOp(OpTy op, ArrayRef operands); - - template - void visitAtenSoftmaxLikeOp(OpTy op, ArrayRef operands); - template - void visitAten_SoftmaxLikeOp(OpTy op, ArrayRef operands); - - void visitNumToTensorOp(PrimNumToTensorScalarOp op); - void visitBinaryScalarOp(Operation *op, - ArrayRef operands); - void visitAtenScalarImplicitOp(AtenScalarImplicitOp op, - ArrayRef operands); - void visitAtenEmbeddingBagOp(Operation *op); -}; -} // namespace - -static torch_upstream::ResultTypeState -updateResultTypeState(Type scalarType, - const torch_upstream::ResultTypeState &inState) { - assert(isBuiltInType(scalarType) && "scalarType must be builtin type"); - 
torch_upstream::ResultTypeState new_state = inState;
- torch_upstream::ScalarType current = getScalarTypeForType(scalarType);
- new_state.wrappedResult =
- promote_skip_undefined(inState.wrappedResult, current);
- return new_state;
-}
-
-// This mostly mirrors the update_result_type_state in
-// aten/src/ATen/native/TypeProperties.* except that we don't support
-// is_wrapped_number as it is a runtime property. From the perspective of
-// torch-mlir, all zero-dim tensors have the same priority.
-//
-// Normally, tensor dimensions need to be known at compile time to do type
-// promotion. `skipRankCheck`, when equal to `true`, is used for special cases
-// where rank doesn't matter. This could be because the operands can sometimes
-// be guaranteed to be non-zero rank, or because the result
-// torch_upstream::ResultTypeState is promoted with a scalar which is guaranteed
-// to be lower priority.
-//
-// The `rankIsNonZero` argument indicates whether the rank is nonzero, zero, or
-// unknown (None variant of the optional).
-static torch_upstream::ResultTypeState
-updateResultTypeState(const ValueKnowledge *tensor,
- std::optional rankIsNonZero,
- const torch_upstream::ResultTypeState &inState,
- bool skipRankCheck = false) {
- if (!rankIsNonZero.has_value() && !skipRankCheck)
- return torch_upstream::ResultTypeState{};
- assert(tensor->dtype && "tensor.dtype must be not none");
-
- torch_upstream::ResultTypeState new_state = inState;
- torch_upstream::ScalarType current = getScalarTypeForType(tensor->dtype);
- if (skipRankCheck || rankIsNonZero.value())
- new_state.dimResult = promote_skip_undefined(inState.dimResult, current);
- else
- new_state.zeroResult = promote_skip_undefined(inState.zeroResult, current);
-
- return new_state;
-}
-
-// Type promotion helper for operators where only scalar operands participate
-// in type promotion, like AtenAddOp.
-//
-// \return The return type is a TorchType.
-static Type getPromotedResultScalarType(ArrayRef scalarTypes) {
- torch_upstream::ResultTypeState state = {};
- for (const Type &scalarType : scalarTypes) {
- state =
- updateResultTypeState(getBuiltInTypeForTorchScalar(scalarType), state);
- }
- FailureOr result = getTorchTypeForScalarType(
- scalarTypes[0].getContext(), result_type(state));
- if (failed(result))
- return Type();
- return *result;
-}
-
-// Returns most generic type Type() if the tensor dtype is unknown.
-static Type getPromotedResultDType(ValueKnowledge *tensor, Type scalarType) {
- if (!tensor->dtype)
- return Type();
- torch_upstream::ResultTypeState state = {};
- // No need to check if rank is zero for tensor because scalar uses
- // wrappedResult which is a lower priority than both dimResult and zeroResult.
- state = updateResultTypeState(tensor, /*rankIsNonZero=*/std::nullopt, state,
- /*skipRankCheck=*/true);
- state =
- updateResultTypeState(getDefaultDtypeForTorchScalar(scalarType), state);
- FailureOr result =
- getTypeForScalarType(scalarType.getContext(), result_type(state));
- return failed(result) ? Type() : *result;
-}
-
-static SmallVector>
-getRankIsNonZeroArray(ValueRange values) {
- SmallVector> rankIsNonZero;
- for (Value v : values) {
- if (auto tensorType = v.getType().dyn_cast()) {
- if (tensorType.hasSizes()) {
- rankIsNonZero.push_back(tensorType.getSizes().size() != 0);
- } else {
- rankIsNonZero.push_back(std::nullopt);
- }
- }
- }
- return rankIsNonZero;
-}
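[Editorial aside, not part of the patch: the dim/zero/wrapped priority that `ResultTypeState` models can be cross-checked against eager PyTorch. A few data points, assuming the default dtype is float32.]

    import torch

    # A non-zero-dim tensor wins over a zero-dim tensor of the same category...
    assert torch.result_type(torch.ones(2, dtype=torch.float32),
                             torch.tensor(1.0, dtype=torch.float64)) == torch.float32
    # ...but a zero-dim tensor of a higher category still promotes the result.
    assert torch.result_type(torch.ones(2, dtype=torch.int32),
                             torch.tensor(1.0)) == torch.float32
    # Python scalars ("wrapped numbers") have the lowest priority of all.
    assert torch.result_type(torch.ones(2, dtype=torch.int64), 2.5) == torch.float32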
-
-// Normally, tensor dimensions need to be known at compile time to do type
-// promotion. `skipRankCheck`, when equal to true, can be used to indicate
-// special cases where tensor operands are guaranteed to be non-zero rank,
-// like the operands of `aten.conv2d` or `aten.mm`, assuming no runtime error.
-//
-// Returns most generic type Type() if the tensor dtype is unknown.
-static Type getPromotedResultType(MLIRContext *context,
- ArrayRef tensors,
- ArrayRef> rankIsNonZero,
- bool skipRankCheck = false) {
- torch_upstream::ResultTypeState state = {};
- assert(tensors.size() == rankIsNonZero.size());
- for (auto t : llvm::zip(tensors, rankIsNonZero)) {
- const ValueKnowledge *tensor = std::get<0>(t);
- std::optional rankIsNonZero = std::get<1>(t);
- if (!tensor->dtype)
- return Type();
- state = updateResultTypeState(tensor, rankIsNonZero, state, skipRankCheck);
- }
- FailureOr result = getTypeForScalarType(context, result_type(state));
- return failed(result) ? Type() : *result;
-}
-
-static Type getPromotedResultTypeAssumingNonZeroRank(
- MLIRContext *context, ArrayRef tensors) {
- SmallVector> rankIsNonZero(tensors.size(), true);
- return getPromotedResultType(context, tensors, rankIsNonZero,
- /*skipRankCheck=*/true);
-}
-
-void TypeAnalysis::fillInDTypeGivenDTypeIntAndInputDType(
- ValueKnowledge &knowledge, Value dtype, Type inputDType) {
- assert(!inputDType ||
- isBuiltInType(inputDType) && "`inputDType` must be a builtin type");
- int64_t dtypeInt;
- if (dtype.getType().isa())
- knowledge.dtype = inputDType;
- else if (matchPattern(dtype, m_TorchConstantInt(&dtypeInt)))
- knowledge.dtype = getTypeForDTypeInteger(dtype.getContext(), dtypeInt);
- else if (auto primDtypeOp = dyn_cast(dtype.getDefiningOp()))
- knowledge.dtype = getLatticeElement(primDtypeOp.getA())->getValue().dtype;
-}
-
-void TypeAnalysis::fillInDTypeGivenDTypeAndDataType(ValueKnowledge &knowledge,
- Value dtype,
- Type dataType) {
- assert(isa(dataType.getDialect()) &&
- "`dataType` must be a torch type");
- Type dtypeForDataType = getDefaultDtypeForTorchScalar(dataType);
- fillInDTypeGivenDTypeIntAndInputDType(knowledge, dtype, dtypeForDataType);
-}
-
-void TypeAnalysis::visitOperation(Operation *op,
- ArrayRef operands,
- ArrayRef results) {
-
- // These ops have results that are dynamically the same as their operands.
- if (isa(op)) {
- incorporateKnowledge(op->getResult(0), operands[0]->getValue());
- return;
- }
-
- // Take dtype from first operand.
- if (isa(op)) {
- return incorporateKnowledge(op->getResult(0), operands[0]->getValue());
- }
-
- // Dtype is always float32, except for bfloat16, float16, float64 and nullptr.
- if (isa(op)) {
- ValueKnowledge knowledge =
- ValueKnowledge::getTensorPessimisticValueState(op->getContext());
- Type dtype = operands[0]->getValue().dtype;
- if (dtype) {
- knowledge.dtype = Float32Type::get(op->getContext());
- if (dtype.isa())
- knowledge.dtype = dtype;
- }
- incorporateKnowledge(op->getResult(0), knowledge);
- return;
- }
-
- // Take dtype from second operand.
- if (isa(op)) {
- auto self = operands[1]->getValue();
- auto knowledge =
- ValueKnowledge::getTensorPessimisticValueState(op->getContext());
- knowledge.dtype = self.dtype;
- incorporateKnowledge(op->getResult(0), knowledge);
- return;
- }
-
- // Dtype is always i1.
- if (isa(op)) {
- auto knowledge =
- ValueKnowledge::getTensorPessimisticValueState(op->getContext());
- knowledge.dtype = IntegerType::get(op->getContext(), 1);
- incorporateKnowledge(op->getResult(0), knowledge);
- return;
- }
-
- // Dtype is always si64.
- if (isa(op)) { - auto knowledge = - ValueKnowledge::getTensorPessimisticValueState(op->getContext()); - knowledge.dtype = - IntegerType::get(op->getContext(), 64, IntegerType::Signed); - incorporateKnowledge(op->getResult(0), knowledge); - return; - } - - // Promote the two dtypes assuming non-zero rank. - if (isa(op)) { - auto knowledge = - ValueKnowledge::getTensorPessimisticValueState(op->getContext()); - knowledge.dtype = getPromotedResultTypeAssumingNonZeroRank( - op->getContext(), {&operands[0]->getValue(), &operands[1]->getValue()}); - incorporateKnowledge(op->getResult(0), knowledge); - return; - } - - // Promote the two dtypes assuming possibly-zero rank. - if (isa(op)) { - auto knowledge = - ValueKnowledge::getTensorPessimisticValueState(op->getContext()); - knowledge.dtype = getPromotedResultType( - op->getContext(), {&operands[0]->getValue(), &operands[1]->getValue()}, - getRankIsNonZeroArray(op->getOperands())); - incorporateKnowledge(op->getResult(0), knowledge); - return; - } - - // Dtype is always float32, except for bfloat16, float64 and nullptr after - // promotion and assuming possible-zero rank. - if (isa(op)) { - ValueKnowledge knowledge = - ValueKnowledge::getTensorPessimisticValueState(op->getContext()); - Type promotedDtype = getPromotedResultType( - op->getContext(), {&operands[0]->getValue(), &operands[1]->getValue()}, - getRankIsNonZeroArray(op->getOperands())); - if (promotedDtype) { - knowledge.dtype = Float32Type::get(op->getContext()); - if (promotedDtype.isa()) - knowledge.dtype = promotedDtype; - } - incorporateKnowledge(op->getResult(0), knowledge); - return; - } - - // Promote three dtypes. - if (isa(op)) { - auto knowledge = - ValueKnowledge::getTensorPessimisticValueState(op->getContext()); - knowledge.dtype = getPromotedResultTypeAssumingNonZeroRank( - op->getContext(), {&operands[0]->getValue(), &operands[1]->getValue(), - &operands[2]->getValue()}); - incorporateKnowledge(op->getResult(0), knowledge); - return; - } - - if (auto linear = llvm::dyn_cast(op)) { - visitAtenLinearOp(linear, operands); - return; - } - - // Promote LHS with scalar RHS. - if (isa(op)) { - auto lhs = operands[0]->getValue(); - Value scalar = op->getOperand(1); - auto knowledge = - ValueKnowledge::getTensorPessimisticValueState(op->getContext()); - knowledge.dtype = getPromotedResultDType(&lhs, scalar.getType()); - incorporateKnowledge(op->getResult(0), knowledge); - return; - } - - // Promote 2nd and 3rd operands. - if (isa(op)) { - auto knowledge = - ValueKnowledge::getTensorPessimisticValueState(op->getContext()); - knowledge.dtype = getPromotedResultType( - op->getContext(), {&operands[1]->getValue(), &operands[2]->getValue()}, - getRankIsNonZeroArray(op->getOperands().slice(1, 2))); - incorporateKnowledge(op->getResult(0), knowledge); - return; - } - - // Promote 2nd and 3rd operands. - if (isa(op)) { - Value lhsScalar = op->getOperand(1); - Value rhsScalar = op->getOperand(2); - auto knowledge = - ValueKnowledge::getTensorPessimisticValueState(op->getContext()); - knowledge.dtype = getDefaultDtypeForTorchScalar(getPromotedResultScalarType( - {lhsScalar.getType(), rhsScalar.getType()})); - incorporateKnowledge(op->getResult(0), knowledge); - return; - } - - // Promote 2nd and 3rd operands. 
- if (isa(op)) { - auto lhs = operands[1]->getValue(); - Value scalar = op->getOperand(2); - auto knowledge = - ValueKnowledge::getTensorPessimisticValueState(op->getContext()); - knowledge.dtype = getPromotedResultDType(&lhs, scalar.getType()); - incorporateKnowledge(op->getResult(0), knowledge); - return; - } - - // Promote 2nd and 3rd operands. - if (isa(op)) { - auto rhs = operands[2]->getValue(); - Value scalar = op->getOperand(1); - auto knowledge = - ValueKnowledge::getTensorPessimisticValueState(op->getContext()); - knowledge.dtype = getPromotedResultDType(&rhs, scalar.getType()); - incorporateKnowledge(op->getResult(0), knowledge); - return; - } - - // 2 results take dtype from first operand. - if (isa(op)) { - auto self = operands[0]->getValue(); - auto result0Knowledge = - ValueKnowledge::getTensorPessimisticValueState(op->getContext()); - result0Knowledge.dtype = self.dtype; - auto result1Knowledge = - ValueKnowledge::getTensorPessimisticValueState(op->getContext()); - result1Knowledge.dtype = self.dtype; - incorporateKnowledge(op->getResult(0), result0Knowledge); - incorporateKnowledge(op->getResult(1), result1Knowledge); - return; - } - - // 3 results take dtype from first operand. - if (isa(op)) { - auto self = operands[0]->getValue(); - auto result0Knowledge = - ValueKnowledge::getTensorPessimisticValueState(op->getContext()); - result0Knowledge.dtype = self.dtype; - auto result1Knowledge = - ValueKnowledge::getTensorPessimisticValueState(op->getContext()); - result1Knowledge.dtype = self.dtype; - auto result2Knowledge = - ValueKnowledge::getTensorPessimisticValueState(op->getContext()); - result2Knowledge.dtype = self.dtype; - incorporateKnowledge(op->getResult(0), result0Knowledge); - incorporateKnowledge(op->getResult(1), result1Knowledge); - incorporateKnowledge(op->getResult(2), result2Knowledge); - return; - } - - if (isa(op)) { - auto self = operands[0]->getValue(); - auto result0Knowledge = - ValueKnowledge::getTensorPessimisticValueState(op->getContext()); - result0Knowledge.dtype = self.dtype; - auto result1Knowledge = - ValueKnowledge::getTensorPessimisticValueState(op->getContext()); - result1Knowledge.dtype = - IntegerType::get(op->getContext(), 64, IntegerType::Signed); - incorporateKnowledge(op->getResult(0), result0Knowledge); - incorporateKnowledge(op->getResult(1), result1Knowledge); - return; - } - - if (auto arange = dyn_cast(op)) { - visitAtenArangeOp(arange); - return; - } - if (auto arangeStart = dyn_cast(op)) { - visitAtenArangeStartOp(arangeStart); - return; - } - if (auto arangeStartStep = dyn_cast(op)) { - visitAtenArangeStartStepOp(arangeStartStep); - return; - } - - if (auto sum = dyn_cast(op)) { - Type defaultDtype = operands[0]->getValue().dtype; - if (!defaultDtype) { - incorporateKnowledge( - sum.getResult(), - ValueKnowledge::getTensorPessimisticValueState(op->getContext())); - return; - } - - // If the input dtype is bool, the result type should be i64. 
- if (defaultDtype.isInteger(1)) - defaultDtype = - IntegerType::get(op->getContext(), 64, IntegerType::Signed); - Type dtype = getDtypeOrDefault(sum.getContext(), sum.getDtype(), defaultDtype); - auto knowledge = - ValueKnowledge::getTensorPessimisticValueState(op->getContext()); - knowledge.dtype = dtype; - incorporateKnowledge(op->getResult(0), knowledge); - return; - } - if (auto sumDimIntList = dyn_cast(op)) { - Type defaultDtype = operands[0]->getValue().dtype; - if (!defaultDtype) { - incorporateKnowledge( - sumDimIntList.getResult(), - ValueKnowledge::getTensorPessimisticValueState(op->getContext())); - return; - } - // If the input dtype is bool, the result type should be i64. - if (defaultDtype.isInteger(1)) - defaultDtype = - IntegerType::get(op->getContext(), 64, IntegerType::Signed); - Type dtype = getDtypeOrDefault(sumDimIntList.getContext(), - sumDimIntList.getDtype(), defaultDtype); - visitReductionAlongDimIntListOp(sumDimIntList, sumDimIntList.getDim(), - sumDimIntList.getKeepdim(), dtype, operands); - return; - } - if (auto meanDim = dyn_cast(op)) { - Type defaultDtype = operands[0]->getValue().dtype; - Type dtype = - getDtypeOrDefault(meanDim.getContext(), meanDim.getDtype(), defaultDtype); - visitReductionAlongDimIntListOp(meanDim, meanDim.getDim(), meanDim.getKeepdim(), - dtype, operands); - return; - } - if (auto argmax = dyn_cast(op)) { - Value dim = argmax.getDim(); - Type dtype = IntegerType::get(op->getContext(), 64, IntegerType::Signed); - if (dim.getType().isa()) { - visitReductionAlongAllDimsOp(op, dtype, operands); - return; - } - if (dim.getType().isa()) { - visitReductionAlongDimIntOp(argmax, argmax.getDim(), argmax.getKeepdim(), dtype, - operands); - return; - } - } - if (auto anyDim = dyn_cast(op)) { - Type dtype = operands[0]->getValue().dtype; - visitReductionAlongDimIntOp(anyDim, anyDim.getDim(), anyDim.getKeepdim(), dtype, - operands); - return; - } - if (auto maxDim = dyn_cast(op)) { - Type firstResDtype = operands[0]->getValue().dtype; - Type secondResDtype = - IntegerType::get(op->getContext(), 64, IntegerType::Signed); - visitReductionAlongDimIntOp(maxDim, maxDim.getDim(), maxDim.getKeepdim(), - firstResDtype, operands); - visitReductionAlongDimIntOp(maxDim, maxDim.getDim(), maxDim.getKeepdim(), - secondResDtype, operands, /*resNum=*/1); - return; - } - if (auto mean = dyn_cast(op)) { - Type defaultDtype = operands[0]->getValue().dtype; - Type dtype = - getDtypeOrDefault(mean.getContext(), mean.getDtype(), defaultDtype); - visitReductionAlongAllDimsOp(mean, dtype, operands); - return; - } else if (isa(op)) { - Type dtype = operands[0]->getValue().dtype; - visitReductionAlongAllDimsOp(op, dtype, operands); - return; - } else if (isa(op)) { - auto input = operands[0]->getValue(); - visitReductionAlongAllDimsOp(op, input.dtype, operands); - return; - } - - if (auto tensorFloat = dyn_cast(op)) { - visitScalarToTensorConversionOp(tensorFloat); - return; - } else if (auto tensorInt = dyn_cast(op)) { - visitScalarToTensorConversionOp(tensorInt); - return; - } else if (auto tensorBool = dyn_cast(op)) { - visitScalarToTensorConversionOp(tensorBool); - return; - } - - if (auto tensor = dyn_cast(op)) { - visitAtenTensorOp(tensor); - return; - } - - if (auto zeros = dyn_cast(op)) { - visitConstantTensorAllocOp(zeros, /*dataType=*/{}); - return; - } else if (auto ones = dyn_cast(op)) { - visitConstantTensorAllocOp(ones, /*dataType=*/{}); - return; - } else if (auto emptyMemoryFormat = dyn_cast(op)) { - visitConstantTensorAllocOp(emptyMemoryFormat, - 
/*dataType=*/{}); - return; - } else if (auto full = dyn_cast(op)) { - visitConstantTensorAllocOp( - full, /*dataType=*/full.getFillValue().getType()); - return; - } else if (auto zerosLike = dyn_cast(op)) { - visitConstantTensorAllocLikeOp(zerosLike, operands); - return; - } else if (auto onesLike = dyn_cast(op)) { - visitConstantTensorAllocLikeOp(onesLike, operands); - return; - } else if (auto emptyLike = dyn_cast(op)) { - visitConstantTensorAllocLikeOp(emptyLike, operands); - return; - } else if (auto fullLike = dyn_cast(op)) { - visitConstantTensorAllocLikeOp(fullLike, operands); - return; - } else if (auto newZeros = dyn_cast(op)) { - visitConstantTensorNewLikeOp(newZeros, operands); - return; - } else if (auto newOnes = dyn_cast(op)) { - visitConstantTensorNewLikeOp(newOnes, operands); - return; - } else if (auto newEmpty = dyn_cast(op)) { - visitConstantTensorNewLikeOp(newEmpty, operands); - return; - } else if (auto newEmptyStrided = dyn_cast(op)) { - visitConstantTensorNewLikeOp(newEmptyStrided, - operands); - return; - } else if (auto randLike = dyn_cast(op)) { - visitConstantTensorAllocLikeOp(randLike, operands); - return; - } else if (auto randLike = dyn_cast(op)) { - visitConstantTensorAllocLikeOp(randLike, operands); - return; - } else if (auto toCopy = dyn_cast(op)) { - visitConstantTensorAllocLikeOp(toCopy, operands); - return; - } - - if (auto toDtype = dyn_cast(op)) { - visitAtenToDtypeLikeOp(toDtype, operands); - return; - } - - if (auto primsConvertElementType = dyn_cast(op)) { - visitAtenToDtypeLikeOp(primsConvertElementType, - operands); - return; - } - - if (auto toDtypeLayout = dyn_cast(op)) { - visitAtenToDtypeLikeOp(toDtypeLayout, operands); - return; - } - - if (auto toDtype = dyn_cast(op)) { - visitAtenToDtypeLikeOp(toDtype, operands); - return; - } - - if (auto toOther = dyn_cast(op)) { - visitTypeConversionOp(toOther, operands); - return; - } else if (auto typeAs = dyn_cast(op)) { - visitTypeConversionOp(typeAs, operands); - return; - } - - if (auto cat = dyn_cast(op)) { - visitAtenCatLikeOp(cat, operands); - return; - } else if (auto stack = dyn_cast(op)) { - visitAtenCatLikeOp(stack, operands); - return; - } - - if (auto shapeAsTensor = dyn_cast(op)) { - auto knowledge = - ValueKnowledge::getTensorPessimisticValueState(op->getContext()); - knowledge.dtype = - IntegerType::get(op->getContext(), 64, IntegerType::Signed); - incorporateKnowledge(shapeAsTensor.getResult(), knowledge); - return; - } - - if (auto embedding = dyn_cast(op)) { - auto knowledge = - ValueKnowledge::getTensorPessimisticValueState(op->getContext()); - knowledge.dtype = operands[0]->getValue().dtype; - incorporateKnowledge(embedding.getResult(), knowledge); - return; - } - - if (isa(op)) { - visitAtenEmbeddingBagOp(op); - return; - } - - if (auto softmaxIntOp = dyn_cast(op)) { - visitAtenSoftmaxLikeOp(softmaxIntOp, operands); - return; - } - if (auto _softmaxOp = dyn_cast(op)) { - visitAten_SoftmaxLikeOp(_softmaxOp, operands); - return; - } else if (auto _logSoftmaxOp = dyn_cast(op)) { - visitAten_SoftmaxLikeOp(_logSoftmaxOp, operands); - return; - } else if (auto logSoftmaxIntOp = dyn_cast(op)) { - visitAtenSoftmaxLikeOp(logSoftmaxIntOp, operands); - return; - } - - if (auto numToTensorOp = dyn_cast(op)) { - visitNumToTensorOp(numToTensorOp); - return; - } - - if (isa(op)) { - visitBinaryScalarOp(op, operands); - return; - } - - if (auto scalarImplicit = dyn_cast(op)) { - visitAtenScalarImplicitOp(scalarImplicit, operands); - return; - } - - if (auto vectorNorm = dyn_cast(op)) { - Type 
defaultDtype = operands[0]->getValue().dtype;
- Type dtype = getDtypeOrDefault(vectorNorm.getContext(), vectorNorm.getDtype(),
- defaultDtype);
- visitReductionAlongDimIntListOp(vectorNorm, vectorNorm.getDim(),
- vectorNorm.getKeepdim(), dtype, operands);
- return;
- }
-
- if (auto randIntLow = dyn_cast(op)) {
- auto knowledge =
- ValueKnowledge::getTensorPessimisticValueState(op->getContext());
- Type defaultDtype =
- IntegerType::get(op->getContext(), 64, IntegerType::Signed);
- knowledge.dtype =
- getDtypeOrDefault(op->getContext(), randIntLow.getDtype(), defaultDtype);
- incorporateKnowledge(randIntLow.getResult(), knowledge);
- return;
- }
-
- if (auto randInt = dyn_cast(op)) {
- auto knowledge =
- ValueKnowledge::getTensorPessimisticValueState(op->getContext());
- Type defaultDtype =
- IntegerType::get(op->getContext(), 64, IntegerType::Signed);
- knowledge.dtype =
- getDtypeOrDefault(op->getContext(), randInt.getDtype(), defaultDtype);
- incorporateKnowledge(randInt.getResult(), knowledge);
- return;
- }
-
- if (isa(op)) {
- auto input = operands[0]->getValue();
- auto knowledge =
- ValueKnowledge::getTensorPessimisticValueState(op->getContext());
- knowledge.dtype = input.dtype;
- incorporateKnowledge(op->getResult(0), knowledge);
- incorporateKnowledge(op->getResult(1), knowledge);
- return;
- }
-
- if (auto randn = dyn_cast(op)) {
- auto knowledge =
- ValueKnowledge::getTensorPessimisticValueState(op->getContext());
- Type defaultDtype = Float32Type::get(op->getContext());
- knowledge.dtype =
- getDtypeOrDefault(op->getContext(), randn.getDtype(), defaultDtype);
- incorporateKnowledge(randn.getResult(), knowledge);
- return;
- }
-
- // aten.sort produces two Tensor outputs. The first is the sorted Tensor,
- // which has the same dtype as the input Tensor, while the second holds the
- // indices of the sorted items in the input Tensor.
- if (auto sortOp = dyn_cast(op)) {
- auto input = operands[0]->getValue();
- auto knowledge =
- ValueKnowledge::getTensorPessimisticValueState(op->getContext());
- knowledge.dtype = input.dtype;
- incorporateKnowledge(op->getResult(0), knowledge);
- Type i64Type = IntegerType::get(op->getContext(), 64, IntegerType::Signed);
- knowledge.dtype = i64Type;
- incorporateKnowledge(op->getResult(1), knowledge);
- return;
- }
-
- if (auto randnGenerator = dyn_cast(op)) {
- auto knowledge =
- ValueKnowledge::getTensorPessimisticValueState(op->getContext());
- Type defaultDtype = Float32Type::get(op->getContext());
- knowledge.dtype = getDtypeOrDefault(op->getContext(),
- randnGenerator.getDtype(), defaultDtype);
- incorporateKnowledge(randnGenerator.getResult(), knowledge);
- return;
- }
-
- if (auto bucketize = dyn_cast(op)) {
- auto knowledge =
- ValueKnowledge::getTensorPessimisticValueState(op->getContext());
- bool outInt32;
- if (matchPattern(bucketize.getOutInt32(), m_TorchConstantBool(&outInt32)) &&
- outInt32) {
- knowledge.dtype =
- IntegerType::get(op->getContext(), 32, IntegerType::Signed);
- } else {
- knowledge.dtype =
- IntegerType::get(op->getContext(), 64, IntegerType::Signed);
- }
- incorporateKnowledge(bucketize.getResult(), knowledge);
- return;
- }
-
- // Otherwise, this is an unknown operation, so reset the state.
- setAllToEntryStates(results);
- return;
-}
-
-void TypeAnalysis::incorporateKnowledge(Value v,
- const ValueKnowledge &knowledge) {
- auto updatedKnowledge = ValueKnowledge::meet(
- knowledge, ValueKnowledge::getPessimisticValueState(v));
- assert(updatedKnowledge.has_value() && "IR has contradictory type!");
- getLatticeElement(v)->join(updatedKnowledge.value());
-}
-
-void TypeAnalysis::visitAtenLinearOp(AtenLinearOp op,
- ArrayRef operands) {
- auto knowledge =
- ValueKnowledge::getTensorPessimisticValueState(op->getContext());
- auto input = operands[0]->getValue();
- auto weight = operands[1]->getValue();
- auto bias = operands[2]->getValue();
- switch (bias.optional) {
- case OptionalKnowledge::isNone:
- knowledge.dtype = getPromotedResultTypeAssumingNonZeroRank(
- op->getContext(), {&input, &weight});
- break;
- case OptionalKnowledge::notNone:
- knowledge.dtype = getPromotedResultTypeAssumingNonZeroRank(
- op->getContext(), {&input, &weight, &bias});
- break;
- case OptionalKnowledge::unKnown:
- // When it's unknown, type promotion can't be decided at compile time.
- break;
- }
- incorporateKnowledge(op->getResult(0), knowledge);
-}
-
-void TypeAnalysis::visitAtenEmbeddingBagOp(Operation *op) {
- auto resultFloatKnowledge =
- ValueKnowledge::getTensorPessimisticValueState(op->getContext());
- resultFloatKnowledge.dtype = Float32Type::get(op->getContext());
-
- incorporateKnowledge(op->getResult(0), resultFloatKnowledge);
- auto resultIntKnowledge =
- ValueKnowledge::getTensorPessimisticValueState(op->getContext());
- resultIntKnowledge.dtype =
- IntegerType::get(op->getContext(), 64, IntegerType::Signed);
-
- for (int64_t i = 1, e = op->getNumResults(); i < e; i++) {
- incorporateKnowledge(op->getResult(i), resultIntKnowledge);
- }
- return;
-}
-
-// Arange-like ops return a 1-D tensor of size ceil(end - start).
-void TypeAnalysis::visitAtenArangeLikeOpHelper(Operation *op,
- std::optional start,
- Value end,
- std::optional step,
- Value dtype) {
- auto knowledge =
- ValueKnowledge::getTensorPessimisticValueState(op->getContext());
- int64_t dtypeInt;
- if (matchPattern(dtype, m_TorchConstantInt(&dtypeInt))) {
- knowledge.dtype = getTypeForDTypeInteger(op->getContext(), dtypeInt);
- } else if (dtype.getType().isa()) {
- // From torch/_torch_docs.py:
- // If `dtype` is not given, infer the data type from the other input
- // arguments. If any of `start`, `end`, or `step` are floating-point, the
- // `dtype` is inferred to be the default dtype, see
- // `torch.get_default_dtype`. Otherwise, the `dtype` is inferred to
- // be `torch.int64`
- if ((start.has_value() && (*start).getType().isa()) ||
- end.getType().isa() ||
- (step.has_value() && (*step).getType().isa())) {
- // TODO: Should get the dtype from torch.get_default_dtype().
- // For now, use float32 which is the initial default dtype.
- knowledge.dtype = Float32Type::get(op->getContext());
- } else
- knowledge.dtype =
- IntegerType::get(op->getContext(), 64, IntegerType::Signed);
- }
- incorporateKnowledge(op->getResult(0), knowledge);
-}
-
-void TypeAnalysis::visitAtenArangeStartStepOp(AtenArangeStartStepOp op) {
- visitAtenArangeLikeOpHelper(op, op.getStart(), op.getEnd(), op.getStep(), op.getDtype());
-}
-
-void TypeAnalysis::visitAtenArangeStartOp(AtenArangeStartOp op) {
- visitAtenArangeLikeOpHelper(op, op.getStart(), op.getEnd(), {}, op.getDtype());
-}
-
-void TypeAnalysis::visitAtenArangeOp(AtenArangeOp op) {
- visitAtenArangeLikeOpHelper(op, {}, op.getEnd(), {}, op.getDtype());
-}
-
-void TypeAnalysis::visitReductionAlongAllDimsOp(
- Operation *op, Type dtype, ArrayRef operands) {
- auto knowledge =
- ValueKnowledge::getTensorPessimisticValueState(op->getContext());
- knowledge.dtype = dtype;
- incorporateKnowledge(op->getResult(0), knowledge);
-}
-
-// These ops do calculation along the dims given by the integer list and reduce
-// each dim to size one. If \p keepdim is false, the dims are squeezed.
-void TypeAnalysis::visitReductionAlongDimIntListOp(
- Operation *op, Value dim, Value keepdim, Type dtype,
- ArrayRef operands) {
- auto knowledge =
- ValueKnowledge::getTensorPessimisticValueState(op->getContext());
- knowledge.dtype = dtype;
- incorporateKnowledge(op->getResult(0), knowledge);
-}
-
-void TypeAnalysis::visitReductionAlongDimIntOp(
- Operation *op, Value dim, Value keepdim, Type dtype,
- ArrayRef operands, int resNum) {
- assert(dim.getType().isa() && "dim must be int type");
- auto knowledge =
- ValueKnowledge::getTensorPessimisticValueState(op->getContext());
- knowledge.dtype = dtype;
- incorporateKnowledge(op->getResult(resNum), knowledge);
-}
-
-template
-void TypeAnalysis::visitScalarToTensorConversionOp(OpTy op) {
- auto knowledge =
- ValueKnowledge::getTensorPessimisticValueState(op.getContext());
- Value t = op.getT();
- Value dtype = op.getDtype();
- fillInDTypeGivenDTypeAndDataType(knowledge, dtype, t.getType());
- incorporateKnowledge(op.getResult(), knowledge);
-}
-
-void TypeAnalysis::visitBinaryScalarOp(Operation *op,
- ArrayRef operands) {
- auto knowledge =
- ValueKnowledge::getScalarPessimisticValueState(op->getContext());
- Type resultType = getPromotedResultScalarType(
- {op->getOperand(0).getType(), op->getOperand(1).getType()});
- knowledge.setScalarType(resultType);
- incorporateKnowledge(op->getResult(0), knowledge);
-}
-
-void TypeAnalysis::visitAtenTensorOp(AtenTensorOp op) {
- auto knowledge =
- ValueKnowledge::getTensorPessimisticValueState(op.getContext());
- Value data = op.getData();
- Value dtype = op.getDtype();
- Type type = data.getType();
- while (auto listType = type.dyn_cast()) {
- type = listType.getContainedType();
- }
- // TODO: Support tensor as the contained type of the list.
- // These are the only types handled by fillInDTypeGivenDTypeAndDataType below.
- if (!type.isa()) { - incorporateKnowledge(op.getResult(), knowledge); - return; - } - fillInDTypeGivenDTypeAndDataType(knowledge, dtype, type); - incorporateKnowledge(op.getResult(), knowledge); -} - -template -void TypeAnalysis::visitConstantTensorAllocOp(OpTy op, - std::optional dataType) { - auto knowledge = - ValueKnowledge::getTensorPessimisticValueState(op->getContext()); - if (!dataType) - dataType = Torch::FloatType::get(op->getContext()); - fillInDTypeGivenDTypeAndDataType(knowledge, op.getDtype(), dataType.value()); - incorporateKnowledge(op.getResult(), knowledge); -} - -template -void TypeAnalysis::visitConstantTensorAllocLikeOp( - OpTy op, ArrayRef operands) { - auto input = operands[0]->getValue(); - auto knowledge = - ValueKnowledge::getTensorPessimisticValueState(op->getContext()); - fillInDTypeGivenDTypeIntAndInputDType(knowledge, op.getDtype(), input.dtype); - incorporateKnowledge(op.getResult(), knowledge); -} - -template -void TypeAnalysis::visitConstantTensorNewLikeOp( - OpTy op, ArrayRef operands) { - auto input = operands[0]->getValue(); - auto knowledge = - ValueKnowledge::getTensorPessimisticValueState(op->getContext()); - fillInDTypeGivenDTypeIntAndInputDType(knowledge, op.getDtype(), input.dtype); - incorporateKnowledge(op.getResult(), knowledge); -} - -// Convert input tensor type to the given `dtype`. -template -void TypeAnalysis::visitAtenToDtypeLikeOp( - OpTy op, ArrayRef operands) { - auto knowledge = - ValueKnowledge::getTensorPessimisticValueState(op->getContext()); - Value dtype = op.getDtype(); - int64_t dtypeInt; - if (matchPattern(dtype, m_TorchConstantInt(&dtypeInt))) - knowledge.dtype = getTypeForDTypeInteger(op->getContext(), dtypeInt); - incorporateKnowledge(op.getResult(), knowledge); -} - -// Convert input tensor type to the same as the other tensor. -template -void TypeAnalysis::visitTypeConversionOp( - OpTy op, ArrayRef operands) { - auto knowledge = - ValueKnowledge::getTensorPessimisticValueState(op->getContext()); - Value other = op.getOther(); - BaseTensorType type = other.getType().cast(); - if (type.hasDtype()) - knowledge.dtype = type.getDtype(); - incorporateKnowledge(op->getResult(0), knowledge); -} - -// `torch.aten.cat` concatenates the given sequence of seq tensors in the given -// dimension. The output has the same sizes as the input for all dimensions -// except the given dimension. -template -void TypeAnalysis::visitAtenCatLikeOp(OpTy op, - ArrayRef operands) { - auto tensorList = op.getTensors(); - auto knowledge = - ValueKnowledge::getTensorPessimisticValueState(op->getContext()); - auto listConstruct = tensorList.template getDefiningOp(); - if (!listConstruct) { - incorporateKnowledge(op.getResult(), knowledge); - return; - } - - SmallVector tensors = llvm::to_vector( - llvm::map_range(listConstruct.getElements(), [&](Value v) -> ValueKnowledge* { - return &getLatticeElement(v)->getValue(); - })); - - knowledge.dtype = getPromotedResultTypeAssumingNonZeroRank( - op->getContext(), tensors); - incorporateKnowledge(op->getResult(0), knowledge); -} - -void TypeAnalysis::visitNumToTensorOp(PrimNumToTensorScalarOp op) { - auto knowledge = - ValueKnowledge::getTensorPessimisticValueState(op->getContext()); - // The resulting type from converting a Scalar into a Tensor is different - // if the scalar is part of a tensor operation (such as AtenMulScalar) or - // not. In the former case, the type promotion rules are captured by the - // `getDefaultDtypeForTorchScalar` helper above. 
The latter case follows the
- // rules in
- // https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/ScalarOps.h.
- // `NumToTensor` falls in the latter case.
- Type type = op.getA().getType();
- knowledge.dtype = getBuiltInTypeForTorchScalar(type);
- incorporateKnowledge(op.getResult(), knowledge);
-}
-
-// Common template for softmax-like ops, e.g., log_softmax.
-template
-void TypeAnalysis::visitAtenSoftmaxLikeOp(
- OpTy op, ArrayRef operands) {
- auto input = operands[0]->getValue();
- auto dtype = op.getDtype();
- ValueKnowledge knowledge =
- ValueKnowledge::getTensorPessimisticValueState(op->getContext());
- fillInDTypeGivenDTypeIntAndInputDType(knowledge, dtype, input.dtype);
- incorporateKnowledge(op.getResult(), knowledge);
-}
-
-// Common template for softmax-like ops, e.g., log_softmax (underscore variant).
-template
-void TypeAnalysis::visitAten_SoftmaxLikeOp(
- OpTy op, ArrayRef operands) {
- auto input = operands[0]->getValue();
- ValueKnowledge knowledge =
- ValueKnowledge::getTensorPessimisticValueState(op->getContext());
- bool halfToFloat;
- if (matchPattern(op.getHalfToFloat(), m_TorchConstantBool(&halfToFloat))) {
- knowledge.dtype =
- halfToFloat ? Float32Type::get(op->getContext()) : input.dtype;
- }
- incorporateKnowledge(op.getResult(), knowledge);
-}
-
-void TypeAnalysis::visitAtenScalarImplicitOp(
- AtenScalarImplicitOp op, ArrayRef operands) {
- auto knowledge =
- ValueKnowledge::getScalarPessimisticValueState(op.getContext());
- Type dType = operands[0]->getValue().dtype;
- if (dType.isa())
- knowledge.setScalarType(Torch::FloatType::get(op->getContext()));
- else if (dType.isa())
- knowledge.setScalarType(Torch::IntType::get(op->getContext()));
- incorporateKnowledge(op->getResult(0), knowledge);
-}
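[Editorial aside, not part of the patch: the two scalar rules used above can be sanity-checked in eager PyTorch.]

    import torch

    # getDefaultDtypeForTorchScalar: a Python float in a tensor op takes the
    # default dtype (float32 here); a Python int becomes int64.
    assert (torch.ones(3, dtype=torch.int64) * 2.5).dtype == torch.float32
    assert torch.tensor(2).dtype == torch.int64

    # aten.ScalarImplicit goes the other way: a float tensor yields a Python
    # float, and an integer tensor yields a Python int.
    assert type(torch.tensor(2.5).item()) is float
    assert type(torch.tensor(2).item()) is int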
-
-// -----------------------------------------------------------------------------
-// Transforms.
-// -----------------------------------------------------------------------------
-
-// Get the most refined type compatible with ValueKnowledge, or null if that
-// is not possible.
-static Type getMostRefinedStaticType(Value v, DataFlowSolver &solver) {
- auto getRefinedTensorType = [](BaseTensorType tensorType,
- ValueKnowledge const &knowledge) {
- return tensorType
- .getWithSizesAndDtype(tensorType.getOptionalSizes(), knowledge.dtype)
- .cast();
- };
- if (auto tensorType = v.getType().dyn_cast()) {
- const ValueState *latticeElement = solver.lookupState(v);
- if (!latticeElement)
- return nullptr;
- const ValueKnowledge &knowledge = latticeElement->getValue();
- if (!knowledge.isInitialized)
- return nullptr;
- return getRefinedTensorType(tensorType, knowledge);
- } else if (auto optionalType = v.getType().dyn_cast()) {
- const ValueState *latticeElement = solver.lookupState(v);
- if (!latticeElement)
- return nullptr;
- const ValueKnowledge &knowledge = latticeElement->getValue();
- if (!knowledge.isInitialized)
- return nullptr;
- if (knowledge.optional == OptionalKnowledge::isNone)
- return Torch::NoneType::get(v.getContext());
- else if (knowledge.optional == OptionalKnowledge::notNone) {
- auto containedType = optionalType.getContainedType();
- if (auto tensorType = containedType.dyn_cast())
- return getRefinedTensorType(tensorType, knowledge);
- else
- return containedType;
- }
- } else if (auto scalarType = v.getType().dyn_cast()) {
- const ValueState *latticeElement = solver.lookupState(v);
- if (!latticeElement)
- return nullptr;
- const ValueKnowledge &knowledge = latticeElement->getValue();
- if (!knowledge.isInitialized)
- return nullptr;
- if (knowledge.kind == torch_upstream::TypeKind::IntType)
- return Torch::IntType::get(v.getContext());
- if (knowledge.kind == torch_upstream::TypeKind::FloatType)
- return Torch::FloatType::get(v.getContext());
- }
- return nullptr;
-}
-
-// Return true if we can safely change the operands or results of `op`.
-//
-// The most trivial case is when the op has the AllowsTypeRefinement trait,
-// which allows arbitrary refinements. But some other cases are safe too,
-// such as when an op has two types that are coupled, but we know that our
-// analysis and updating logic will correctly maintain the invariants of the op.
-// The `torch.copy.to_tensor` / `torch.copy.to_vtensor` ops are examples of the
-// latter case, since their operand and result types must have the same shape
-// and dtype -- we know that our transfer functions and updating logic will do
-// the right thing for those ops.
-//
-static bool allowsTypeRefinementOrIsSafeToRefine(Operation *op) {
- return op->hasTrait() ||
- isa(op);
-}
-
-// Some operations have extra verification logic regarding the relationship
-// between the input types and output types. Adding more refined type info to
-// the operand might change a valid instruction to be invalid.
-static bool operationIsValidWithRefinedType(OpOperand *use, Type newType) {
- Operation *op = use->getOwner();
- if (auto uncheckedCast = llvm::dyn_cast(op))
- return uncheckedCast.areCastCompatible(newType, uncheckedCast.getType());
- return true;
-}
-
-static bool isSafeToRefineOperandInPlace(OpOperand *use, Type newOperandType) {
- Operation *op = use->getOwner();
- if (!allowsTypeRefinementOrIsSafeToRefine(op))
- return false;
- return operationIsValidWithRefinedType(use, newOperandType);
-}
-
-void optimize(func::FuncOp func, DataFlowSolver &solver) {
- func.walk([&](Operation *op) {
- auto convertValuesToMostRefinedType = [&](ValueRange values, OpBuilder &b) {
- for (Value v : values) {
- Type refinedType = getMostRefinedStaticType(v, solver);
- Type originalType = v.getType();
- // No type?
Nothing to do. - if (!refinedType) - continue; - // Type is same as existing one? Nothing to do. - if (refinedType == originalType) - continue; - // If we have an op that allows adding/removing static information from - // this type, then we can rewrite. We make sure to always embed the - // static information in the IR, and insert the minimal number of casts - // needed to do so. - using CreateStaticInfoCastFn = - std::function; - CreateStaticInfoCastFn createStaticInfoDownCast; - CreateStaticInfoCastFn createStaticInfoUpCast; - if (originalType.isa()) { - createStaticInfoDownCast = [&](Location loc, Type newType, - Value v) -> Value { - return b.create(loc, newType, v); - }; - createStaticInfoUpCast = createStaticInfoDownCast; - } else if (originalType.isa()) { - createStaticInfoDownCast = [&](Location loc, Type newType, - Value v) -> Value { - return b.create(loc, newType, v); - }; - createStaticInfoUpCast = [&](Location loc, Type newType, - Value v) -> Value { - return b.create(loc, newType, v); - }; - } - - if (createStaticInfoUpCast) { - assert(createStaticInfoDownCast && - "createStaticInfoDownCast and createStaticInfoUpCast must be " - "defined in pairs"); - // Save off the original uses to avoid iterator invalidation issues - // or other unexpected behavior since we are creating new ops here - // that use the value. - auto originalUses = llvm::to_vector<6>(llvm::map_range( - v.getUses(), [](OpOperand &use) { return &use; })); - OpBuilder b(op->getBlock(), std::next(op->getIterator())); - Value newTypedValue; - // Always make sure that the new static information is reflected in - // the IR, either by updating the type in place, or inserting a static - // info cast. - if (allowsTypeRefinementOrIsSafeToRefine(op)) { - newTypedValue = v; - v.setType(refinedType); - } else { - if (auto derefineOp = llvm::dyn_cast(op)) { - newTypedValue = derefineOp.getOperand(); - } else { - newTypedValue = - createStaticInfoDownCast(op->getLoc(), refinedType, v); - } - } - - Value oldTypedValue; - for (OpOperand *use : originalUses) { - // If the use can be updated to the new type directly, do it! - if (isSafeToRefineOperandInPlace(use, refinedType)) { - use->set(newTypedValue); - continue; - } else if (auto overwriteTensorContents = - dyn_cast( - use->getOwner())) { - // `OverwriteTensorContentsOp` has special handling here because - // it requires that both of its operands always have the same - // shape and dtype. - // - // WARNING: In order to simplify the implementation, the type - // used for both operands is the type of the overwritten tensor. - // A better way of doing this would be to join the two operand - // types to create the most specific type possible and use that - // for both arguments, allowing static sizes to always propagate. 
- const unsigned overwriterOperandIndex = 0; - const unsigned overwrittenOperandIndex = 1; - unsigned operandNumber = use->getOperandNumber(); - if (operandNumber != overwrittenOperandIndex) - continue; - - Location loc = overwriteTensorContents.getLoc(); - Value overwriterTensor = overwriteTensorContents.getValue(); - Type overwriterTensorType = overwriterTensor.getType(); - Type overwrittenTensorType = newTypedValue.getType() - .dyn_cast() - .getWithValueSemantics(); - if (overwriterTensorType == overwrittenTensorType) - continue; - - { - OpBuilder::InsertionGuard guard(b); - b.setInsertionPoint(overwriteTensorContents); - Value castedOverwriterTensor = b.create( - loc, overwrittenTensorType, overwriterTensor); - overwriteTensorContents.setOperand(overwriterOperandIndex, - castedOverwriterTensor); - } - continue; - } - - // If needed, create a value of the original type to appease users - // that cannot accept the new type. - if (!oldTypedValue) { - if (auto derefineOp = llvm::dyn_cast(op)) { - oldTypedValue = derefineOp.getResult(); - } else { - oldTypedValue = createStaticInfoUpCast( - op->getLoc(), originalType, newTypedValue); - } - } - use->set(oldTypedValue); - } - } - } - }; - - if (auto branch = dyn_cast(op)) { - for (auto ®ion : branch->getRegions()) { - OpBuilder b(region); - convertValuesToMostRefinedType(region.front().getArguments(), b); - } - } - OpBuilder b(op->getBlock(), std::next(op->getIterator())); - convertValuesToMostRefinedType(op->getResults(), b); - }); -} - -namespace { -class RefineTypesPass : public RefineTypesBase { - void runOnOperation() override { - auto func = getOperation(); - DataFlowSolver solver; - solver.load(); - solver.load(); - solver.load(); - if (failed(solver.initializeAndRun(func))) - return signalPassFailure(); - optimize(func, solver); - } -}; -} // namespace - -std::unique_ptr> -mlir::torch::Torch::createRefineTypesPass() { - return std::make_unique(); -} diff --git a/lib/Dialect/Torch/Transforms/ReifyAbstractInterpCalculationsUtils.cpp b/lib/Dialect/Torch/Transforms/ReifyAbstractInterpCalculationsUtils.cpp index 6e3d054eb..8e6b5888b 100644 --- a/lib/Dialect/Torch/Transforms/ReifyAbstractInterpCalculationsUtils.cpp +++ b/lib/Dialect/Torch/Transforms/ReifyAbstractInterpCalculationsUtils.cpp @@ -176,20 +176,23 @@ FailureOr Torch::adjustFunctionArg( return b.create(loc, desiredType, operand).getResult(); } - // !torch.union is the type used for `Scalar` inputs. At - // compile time, such inputs will usually be resolved to an `int` or a `float` - // so we need to derefine to match the library function signature. + // !torch.union or !torch.union is the type used + // for (optional) `Scalar` inputs. At compile time, such inputs will usually + // be resolved to an `int` or a `float` so we need to derefine to match the + // library function signature. if (auto unionType = desiredType.dyn_cast()) { if (llvm::all_of(unionType.getContainedTypes(), [](Type containedType) { - return containedType.isa(); + return containedType + .isa(); })) return b.create(loc, desiredType, operand).getResult(); } - // If the operand is NoneType, then we just need to derefine it to the - // optional type in the function signature. + // Operands with type `!torch.none` correspond to library function inputs with + // types like `!torch.optional<...>` or `!torch.union<..., none>`, so here the + // type is derefined to match the expected type of the library function. 
if (operandType.isa()) {
-    assert(desiredType.isa() &&
+    assert(!desiredType.isa() &&
"Don't expect library functions to have NoneType parameters");
return b.create(loc, desiredType, operand).getResult();
}
diff --git a/lib/Dialect/Torch/Transforms/SimplifyAbstractInterpCalculationsUtils.cpp b/lib/Dialect/Torch/Transforms/SimplifyAbstractInterpCalculationsUtils.cpp
index fd58ead00..1a2d3d545 100644
--- a/lib/Dialect/Torch/Transforms/SimplifyAbstractInterpCalculationsUtils.cpp
+++ b/lib/Dialect/Torch/Transforms/SimplifyAbstractInterpCalculationsUtils.cpp
@@ -8,11 +8,248 @@
 //===----------------------------------------------------------------------===//
 
 #include "SimplifyAbstractInterpCalculationsUtils.h"
+#include "mlir/IR/IRMapping.h"
+#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
 
 using namespace mlir;
 using namespace mlir::torch;
 using namespace mlir::torch::Torch;
 
+namespace {
+class FoldPrimUncheckedCastOp : public OpRewritePattern {
+public:
+ using OpRewritePattern::OpRewritePattern;
+ LogicalResult matchAndRewrite(PrimUncheckedCastOp op,
+ PatternRewriter &rewriter) const override {
+ if (!isValidSubtype(op.getX().getType(), op.getResult().getType())) {
+ return rewriter.notifyMatchFailure(
+ op, "input tensor type is not a valid subtype of result type");
+ }
+ rewriter.replaceOp(op, op.getX());
+ return success();
+ }
+};
+} // namespace
+
+namespace {
+// TODO: Only unroll inside the shape calculation region.
+// Maybe do this by only applying patterns and folding greedily on the ops
+// inside the region + the shape.calculate op itself?
+class FullyUnrollPrimLoopOp : public OpRewritePattern {
+public:
+ using OpRewritePattern::OpRewritePattern;
+ LogicalResult matchAndRewrite(PrimLoopOp op,
+ PatternRewriter &rewriter) const override {
+ Location loc = op->getLoc();
+ MLIRContext *context = op->getContext();
+ if (!op.isForLike())
+ return rewriter.notifyMatchFailure(op, "Loop is not for-like");
+ int64_t maxTripCount;
+ if (!matchPattern(op.getMaxTripCount(), m_TorchConstantInt(&maxTripCount)))
+ return rewriter.notifyMatchFailure(
+ op, "Expected `maxTripCount` to be a constant int");
+ SmallVector indices;
+ for (int64_t i = 0; i < maxTripCount; i++) {
+ // TODO: Add convenience builder.
+      indices.push_back(rewriter.create<ConstantIntOp>(
+          loc, rewriter.getIntegerAttr(IntegerType::get(context, 64), i)));
+    }
+    Block *beforeBlock = op->getBlock();
+    Block *afterBlock = rewriter.splitBlock(op->getBlock(), op->getIterator());
+
+    SmallVector<Block *> blocksToMerge;
+    IRMapping bvm;
+    // TODO: Helper for region().front()
+    auto condition =
+        cast<PrimLoopConditionOp>(op.getRegion().front().getTerminator());
+    for (int64_t i = 0; i < maxTripCount; i++) {
+      SmallVector<Value> iterArgs;
+      if (i == 0) {
+        llvm::append_range(iterArgs, op.getIterArgsInit());
+      } else {
+        llvm::append_range(
+            iterArgs, llvm::map_range(condition.getIterArgs(),
+                                      [&](Value v) { return bvm.lookup(v); }));
+      }
+      bvm.clear();
+      bvm.map(op.getRegion().front().getArgument(0), indices[i]);
+      bvm.map(op.getRegion().front().getArguments().slice(1), iterArgs);
+
+      op.getRegion().cloneInto(afterBlock->getParent(),
+                               afterBlock->getIterator(), bvm);
+      Block *clonedBlock = bvm.lookup(&op.getRegion().front());
+      rewriter.eraseOp(clonedBlock->getTerminator());
+      blocksToMerge.push_back(clonedBlock);
+    }
+
+    blocksToMerge.push_back(afterBlock);
+    for (Block *block : blocksToMerge)
+      rewriter.mergeBlocks(block, beforeBlock);
+    if (maxTripCount == 0) {
+      rewriter.replaceOp(op, op.getIterArgsInit());
+    } else {
+      rewriter.replaceOp(op, llvm::to_vector<6>(llvm::map_range(
+                                 condition.getIterArgs(),
+                                 [&](Value v) { return bvm.lookup(v); })));
+    }
+    return success();
+  }
+};
+} // namespace
+
+namespace {
+class AbstractlyInterpretListOpsWithinABlock
+    : public OpRewritePattern<PrimListConstructOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(PrimListConstructOp op,
+                                PatternRewriter &rewriter) const override {
+    Block *block = op->getBlock();
+    auto allUsers = llvm::to_vector<6>(op->getUsers());
+
+    // Sort the users into program order.
+    auto getParentInBlock = [&](Operation *op) {
+      while (op->getBlock() != block)
+        op = op->getParentOp();
+      return op;
+    };
+    // Use a stable sort for deterministic results when users are nested in two
+    // regions of the same parent op.
+    llvm::stable_sort(allUsers, [&](Operation *lhs, Operation *rhs) {
+      return getParentInBlock(lhs)->isBeforeInBlock(getParentInBlock(rhs));
+    });
+
+    // We cannot interpret all ops. So first do a check to see up until which
+    // point we can interpret.
+    int numUsersToInterpret = 0;
+    for (int i = 0, e = allUsers.size(); i != e; i++, numUsersToInterpret++) {
+      Operation *user = allUsers[i];
+      // If a user potentially mutates the list, then we require it to be in the
+      // same block for our simple abstract interpretation to work (we can't,
+      // for example, handle an "append" operation in a loop or other region).
+      // However, if the op is read-only, then for the purpose of our abstract
+      // interpretation, we can handle it effectively as though it was at the
+      // same position as the corresponding parent op in the block under
+      // consideration.
+      if (potentiallyMutatesListOperands(user)) {
+        if (user->getBlock() != block)
+          break;
+      }
+    }
+
+    // Truncate the list of users to the number of users we're going to
+    // interpret.
+    allUsers.resize(numUsersToInterpret);
+    auto usersToInterpret = ArrayRef(allUsers).take_front(numUsersToInterpret);
+
+    // For each mutating op (which must be in the same block), we save the
+    // current state of the list as a vector of Value's. These will then
+    // be converted to PrimListConstructOp's at the correct program points.
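+    // listLiterals[i] holds the list contents immediately after the i-th
+    // interpreted mutation executes; read-only users in between are pointed
+    // at the most recently materialized literal.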
+    SmallVector<SmallVector<Value>> listLiterals;
+    SmallVector<Value> runningList;
+    llvm::append_range(runningList, op->getOperands());
+    bool generatedNewLiteral = false;
+    for (Operation *user : usersToInterpret) {
+      if (auto append = dyn_cast<AtenAppendTOp>(user)) {
+        if (!append.use_empty())
+          return rewriter.notifyMatchFailure(
+              op, "Expected `AtenAppendTOp` to not have users");
+        if (append.getSelf() == op) {
+          runningList.push_back(append.getEl());
+          generatedNewLiteral = true;
+        }
+        listLiterals.push_back(runningList);
+        continue;
+      }
+      if (auto insert = dyn_cast<AtenInsertTOp>(user)) {
+        if (!insert.use_empty())
+          return rewriter.notifyMatchFailure(
+              op, "Expected `AtenInsertTOp` to not have users");
+        int64_t index;
+        if (!matchPattern(insert.getIdx(), m_TorchConstantInt(&index)))
+          return rewriter.notifyMatchFailure(
+              op, "Expected `idx` of `AtenInsertTOp` to be a constant int");
+        // The index might be statically out of bounds.
+        if (index < 0 || index > static_cast<int64_t>(runningList.size()))
+          return rewriter.notifyMatchFailure(
+              op, "Index in `AtenInsertTOp` is out of bounds");
+        if (insert.getSelf() == op) {
+          runningList.insert(runningList.begin() + index, insert.getEl());
+          generatedNewLiteral = true;
+        }
+        listLiterals.push_back(runningList);
+        continue;
+      }
+      if (auto setItem = dyn_cast<Aten_SetItemTOp>(user)) {
+        if (!setItem.use_empty())
+          return rewriter.notifyMatchFailure(
+              op, "Expected `Aten_SetItemTOp` to not have users");
+        std::optional<int64_t> indexOpt = matchLegalConstantIndexIntoListOfSize(
+            setItem.getIdx(), runningList.size());
+        // The index might be statically out of bounds.
+        if (!indexOpt)
+          return rewriter.notifyMatchFailure(
+              op, "Index in `Aten_SetItemTOp` is out of bounds");
+        if (setItem.getL() == op) {
+          runningList[*indexOpt] = setItem.getEl();
+          generatedNewLiteral = true;
+        }
+        listLiterals.push_back(runningList);
+        continue;
+      }
+      // If this user potentially mutates the list and isn't handled above, then
+      // we can't abstractly interpret any further.
+      if (potentiallyMutatesListOperands(user))
+        break;
+    }
+
+    if (!generatedNewLiteral)
+      return rewriter.notifyMatchFailure(op, "No new literal created");
+
+    // Rewrite all users to use the appropriate list literals.
+    Value latestLiteral = rewriter.create<PrimListConstructOp>(
+        op->getLoc(), op.getType(), op->getOperands());
+    int nextLiteral = 0;
+    for (Operation *user : usersToInterpret) {
+      if (auto append = dyn_cast<AtenAppendTOp>(user)) {
+        rewriter.setInsertionPoint(append);
+        latestLiteral = rewriter.create<PrimListConstructOp>(
+            append->getLoc(), op.getType(), listLiterals[nextLiteral++]);
+        if (append.getSelf() == op)
+          rewriter.eraseOp(append);
+        continue;
+      }
+      if (auto insert = dyn_cast<AtenInsertTOp>(user)) {
+        rewriter.setInsertionPoint(insert);
+        latestLiteral = rewriter.create<PrimListConstructOp>(
+            insert->getLoc(), op.getType(), listLiterals[nextLiteral++]);
+        if (insert.getSelf() == op)
+          rewriter.eraseOp(insert);
+        continue;
+      }
+      if (auto setItem = dyn_cast<Aten_SetItemTOp>(user)) {
+        rewriter.setInsertionPoint(setItem);
+        latestLiteral = rewriter.create<PrimListConstructOp>(
+            setItem->getLoc(), op.getType(), listLiterals[nextLiteral++]);
+        if (setItem.getL() == op)
+          rewriter.eraseOp(setItem);
+        continue;
+      }
+      for (OpOperand &opOperand : user->getOpOperands()) {
+        if (opOperand.get() == op.getResult()) {
+          opOperand.set(latestLiteral);
+        }
+      }
+    }
+
+    // Any remaining uses should use the updated value of the latest literal.
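+    // replaceOp below also erases the original PrimListConstructOp, rewiring
+    // any use not already updated above to the final list literal.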
+    rewriter.replaceOp(op, latestLiteral);
+    return success();
+  }
+};
+} // namespace
+
 LogicalResult
 Torch::updateCalculateOpResultTypes(Operation *calculateOp, int resultNum,
                                     Type newResultType,
@@ -97,3 +334,18 @@ LogicalResult Torch::updateCalculateOpResultTypes(Operation *calculateOp,
 
   return success();
 }
+
+void mlir::torch::Torch::populateFoldPrimUncheckedCastOpPattern(
+    RewritePatternSet &patterns, MLIRContext *context) {
+  patterns.insert<FoldPrimUncheckedCastOp>(context);
+}
+
+void mlir::torch::Torch::populateFullyUnrollPrimLoopOpPattern(
+    RewritePatternSet &patterns, MLIRContext *context) {
+  patterns.insert<FullyUnrollPrimLoopOp>(context);
+}
+
+void mlir::torch::Torch::populateAbstractlyInterpretListOpsWithinABlockPattern(
+    RewritePatternSet &patterns, MLIRContext *context) {
+  patterns.insert<AbstractlyInterpretListOpsWithinABlock>(context);
+}
diff --git a/lib/Dialect/Torch/Transforms/SimplifyAbstractInterpCalculationsUtils.h b/lib/Dialect/Torch/Transforms/SimplifyAbstractInterpCalculationsUtils.h
index 9c618d4a2..172d27c00 100644
--- a/lib/Dialect/Torch/Transforms/SimplifyAbstractInterpCalculationsUtils.h
+++ b/lib/Dialect/Torch/Transforms/SimplifyAbstractInterpCalculationsUtils.h
@@ -23,6 +23,13 @@ LogicalResult updateCalculateOpResultTypes(Operation *calculateOp,
                                            int resultNum, Type newResultType,
                                            PatternRewriter &rewriter);
 
+void populateFoldPrimUncheckedCastOpPattern(RewritePatternSet &patterns,
+                                            MLIRContext *context);
+void populateFullyUnrollPrimLoopOpPattern(RewritePatternSet &patterns,
+                                          MLIRContext *context);
+void populateAbstractlyInterpretListOpsWithinABlockPattern(
+    RewritePatternSet &patterns, MLIRContext *context);
+
 } // namespace Torch
 } // namespace torch
 } // namespace mlir
diff --git a/lib/Dialect/Torch/Transforms/SimplifyDtypeCalculations.cpp b/lib/Dialect/Torch/Transforms/SimplifyDtypeCalculations.cpp
index 3c4a334b5..43f2b22a3 100644
--- a/lib/Dialect/Torch/Transforms/SimplifyDtypeCalculations.cpp
+++ b/lib/Dialect/Torch/Transforms/SimplifyDtypeCalculations.cpp
@@ -191,10 +191,17 @@ class SimplifyDtypeCalculationsPass
     MLIRContext *context = &getContext();
 
     RewritePatternSet patterns(context);
+    populateFullyUnrollPrimLoopOpPattern(patterns, context);
+    populateAbstractlyInterpretListOpsWithinABlockPattern(patterns, context);
+    populateFoldPrimUncheckedCastOpPattern(patterns, context);
     patterns.insert(context);
     patterns.insert(context);
     patterns.insert(context);
 
+    PrimIfOp::getCanonicalizationPatterns(patterns, context);
+    Aten__Getitem__TOp::getCanonicalizationPatterns(patterns, context);
+    PrimTupleUnpackOp::getCanonicalizationPatterns(patterns, context);
+
     // TODO: Debug visitation order to make this more efficient.
     // A single linear scan should suffice.
     GreedyRewriteConfig config;
diff --git a/lib/Dialect/Torch/Transforms/SimplifyShapeCalculations.cpp b/lib/Dialect/Torch/Transforms/SimplifyShapeCalculations.cpp
index f8d3651d9..1669be7c4 100644
--- a/lib/Dialect/Torch/Transforms/SimplifyShapeCalculations.cpp
+++ b/lib/Dialect/Torch/Transforms/SimplifyShapeCalculations.cpp
@@ -10,7 +10,6 @@
 
 #include "PassDetail.h"
 #include "SimplifyAbstractInterpCalculationsUtils.h"
-#include "mlir/IR/IRMapping.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "torch-mlir/Dialect/Torch/Transforms/Passes.h"
 #include "torch-mlir/Dialect/Torch/Utils/Utils.h"
@@ -19,225 +18,6 @@ using namespace mlir;
 using namespace mlir::torch;
 using namespace mlir::torch::Torch;
 
-namespace {
-// TODO: Only unroll inside the shape calculation region.
-// Maybe do this by only applying patterns and folding greedily on the ops
-// inside the region + the shape.calculate op itself?
-class FullyUnrollPrimLoopOp : public OpRewritePattern<PrimLoopOp> {
-public:
-  using OpRewritePattern::OpRewritePattern;
-  LogicalResult matchAndRewrite(PrimLoopOp op,
-                                PatternRewriter &rewriter) const override {
-    Location loc = op->getLoc();
-    MLIRContext *context = op->getContext();
-    if (!op.isForLike())
-      return rewriter.notifyMatchFailure(op, "Loop is not for-like");
-    int64_t maxTripCount;
-    if (!matchPattern(op.getMaxTripCount(), m_TorchConstantInt(&maxTripCount)))
-      return rewriter.notifyMatchFailure(
-          op, "Expected `maxTripCount` to be a constant int");
-    ;
-    SmallVector<Value> indices;
-    for (int64_t i = 0; i < maxTripCount; i++) {
-      // TODO: Add convenience builder.
-      indices.push_back(rewriter.create<ConstantIntOp>(
-          loc, rewriter.getIntegerAttr(IntegerType::get(context, 64), i)));
-    }
-    Block *beforeBlock = op->getBlock();
-    Block *afterBlock = rewriter.splitBlock(op->getBlock(), op->getIterator());
-
-    SmallVector<Block *> blocksToMerge;
-    IRMapping bvm;
-    // TODO: Helper for region().front()
-    auto condition =
-        cast<PrimLoopConditionOp>(op.getRegion().front().getTerminator());
-    for (int64_t i = 0; i < maxTripCount; i++) {
-      SmallVector<Value> iterArgs;
-      if (i == 0) {
-        llvm::append_range(iterArgs, op.getIterArgsInit());
-      } else {
-        llvm::append_range(
-            iterArgs, llvm::map_range(condition.getIterArgs(),
-                                      [&](Value v) { return bvm.lookup(v); }));
-      }
-      bvm.clear();
-      bvm.map(op.getRegion().front().getArgument(0), indices[i]);
-      bvm.map(op.getRegion().front().getArguments().slice(1), iterArgs);
-
-      op.getRegion().cloneInto(afterBlock->getParent(), afterBlock->getIterator(),
-                               bvm);
-      Block *clonedBlock = bvm.lookup(&op.getRegion().front());
-      rewriter.eraseOp(clonedBlock->getTerminator());
-      blocksToMerge.push_back(clonedBlock);
-    }
-
-    blocksToMerge.push_back(afterBlock);
-    for (Block *block : blocksToMerge)
-      rewriter.mergeBlocks(block, beforeBlock);
-    if (maxTripCount == 0) {
-      rewriter.replaceOp(op, op.getIterArgsInit());
-    } else {
-      rewriter.replaceOp(op, llvm::to_vector<6>(llvm::map_range(
-                                 condition.getIterArgs(),
-                                 [&](Value v) { return bvm.lookup(v); })));
-    }
-    return success();
-  }
-};
-} // namespace
-
-namespace {
-class AbstractlyInterpretListOpsWithinABlock
-    : public OpRewritePattern<PrimListConstructOp> {
-public:
-  using OpRewritePattern::OpRewritePattern;
-  LogicalResult matchAndRewrite(PrimListConstructOp op,
-                                PatternRewriter &rewriter) const override {
-    Block *block = op->getBlock();
-    auto allUsers = llvm::to_vector<6>(op->getUsers());
-
-    // Sort the users into program order.
-    auto getParentInBlock = [&](Operation *op) {
-      while (op->getBlock() != block)
-        op = op->getParentOp();
-      return op;
-    };
-    // Use a stable sort for deterministic results when users are nested in two
-    // regions of the same parent op.
-    llvm::stable_sort(allUsers, [&](Operation *lhs, Operation *rhs) {
-      return getParentInBlock(lhs)->isBeforeInBlock(getParentInBlock(rhs));
-    });
-
-    // We cannot interpret all ops. So first do a check to see up until which
-    // point we can interpret.
-    int numUsersToInterpret = 0;
-    for (int i = 0, e = allUsers.size(); i != e; i++, numUsersToInterpret++) {
-      Operation *user = allUsers[i];
-      // If a user potentially mutates the list, then we require it to be in the
-      // same block for our simple abstract interpretation to work (we can't,
-      // for example, handle an "append" operation in a loop or other region).
-      // However, if the op is read-only, then from the purpose of our abstract
-      // interpretation, we can handle it effectively as though it was at the
-      // same position as the corresponding parent op in the block under
-      // consideration.
-      if (potentiallyMutatesListOperands(user)) {
-        if (user->getBlock() != block)
-          break;
-      }
-    }
-
-    // Truncate the list of users to the number of users we're going to
-    // interpret.
-    allUsers.resize(numUsersToInterpret);
-    auto usersToInterpret = ArrayRef(allUsers).take_front(numUsersToInterpret);
-
-    // For each mutating op (which must be in the same block), we save the
-    // current state of the list as a vector of Value's. These will then
-    // be converted to PrimListConstructOp's at the correct program points.
-    SmallVector<SmallVector<Value>> listLiterals;
-    SmallVector<Value> runningList;
-    llvm::append_range(runningList, op->getOperands());
-    bool generatedNewLiteral = false;
-    for (Operation *user : usersToInterpret) {
-      if (auto append = dyn_cast<AtenAppendTOp>(user)) {
-        if (!append.use_empty())
-          return rewriter.notifyMatchFailure(
-              op, "Expected `AtenAppendTOp` to not have users");
-        if (append.getSelf() == op) {
-          runningList.push_back(append.getEl());
-          generatedNewLiteral = true;
-        }
-        listLiterals.push_back(runningList);
-        continue;
-      }
-      if (auto insert = dyn_cast<AtenInsertTOp>(user)) {
-        if (!insert.use_empty())
-          return rewriter.notifyMatchFailure(
-              op, "Expected `AtenInsertTOp` to not have users");
-        int64_t index;
-        if (!matchPattern(insert.getIdx(), m_TorchConstantInt(&index)))
-          return rewriter.notifyMatchFailure(
-              op, "Expected `idx` of `AtenInsertTOp` to be a constant int");
-        // The index might be statically out of bounds.
-        if (index < 0 || index > static_cast<int64_t>(runningList.size()))
-          return rewriter.notifyMatchFailure(
-              op, "Index in `AtenInsertTOp` is out of bounds");
-        if (insert.getSelf() == op) {
-          runningList.insert(runningList.begin() + index, insert.getEl());
-          generatedNewLiteral = true;
-        }
-        listLiterals.push_back(runningList);
-        continue;
-      }
-      if (auto setItem = dyn_cast<Aten_SetItemTOp>(user)) {
-        if (!setItem.use_empty())
-          return rewriter.notifyMatchFailure(
-              op, "Expected `Aten_SetItemTOp` to not have users");
-        std::optional<int64_t> indexOpt = matchLegalConstantIndexIntoListOfSize(
-            setItem.getIdx(), runningList.size());
-        // The index might be statically out of bounds.
-        if (!indexOpt)
-          return rewriter.notifyMatchFailure(
-              op, "Index in `Aten_SetItemTOp` is out of bounds");
-        if (setItem.getL() == op) {
-          runningList[*indexOpt] = setItem.getEl();
-          generatedNewLiteral = true;
-        }
-        listLiterals.push_back(runningList);
-        continue;
-      }
-      // If this user potentially mutates the list and isn't handled above, then
-      // we can't abstractly interpret any further.
-      if (potentiallyMutatesListOperands(user))
-        break;
-    }
-
-    if (!generatedNewLiteral)
-      return rewriter.notifyMatchFailure(op, "No new literal created");
-
-    // Rewrite all users to use the appropriate list literals.
-    Value latestLiteral = rewriter.create<PrimListConstructOp>(
-        op->getLoc(), op.getType(), op->getOperands());
-    int nextLiteral = 0;
-    for (Operation *user : usersToInterpret) {
-      if (auto append = dyn_cast<AtenAppendTOp>(user)) {
-        rewriter.setInsertionPoint(append);
-        latestLiteral = rewriter.create<PrimListConstructOp>(
-            append->getLoc(), op.getType(), listLiterals[nextLiteral++]);
-        if (append.getSelf() == op)
-          rewriter.eraseOp(append);
-        continue;
-      }
-      if (auto insert = dyn_cast<AtenInsertTOp>(user)) {
-        rewriter.setInsertionPoint(insert);
-        latestLiteral = rewriter.create<PrimListConstructOp>(
-            insert->getLoc(), op.getType(), listLiterals[nextLiteral++]);
-        if (insert.getSelf() == op)
-          rewriter.eraseOp(insert);
-        continue;
-      }
-      if (auto setItem = dyn_cast<Aten_SetItemTOp>(user)) {
-        rewriter.setInsertionPoint(setItem);
-        latestLiteral = rewriter.create<PrimListConstructOp>(
-            setItem->getLoc(), op.getType(), listLiterals[nextLiteral++]);
-        if (setItem.getL() == op)
-          rewriter.eraseOp(setItem);
-        continue;
-      }
-      for (OpOperand &opOperand : user->getOpOperands()) {
-        if (opOperand.get() == op.getResult()) {
-          opOperand.set(latestLiteral);
-        }
-      }
-    }
-
-    // Any remaining uses should use the updated value of the latest literal.
-    rewriter.replaceOp(op, latestLiteral);
-    return success();
-  }
-};
-} // namespace
-
 namespace {
 class DecomposeAtenSizeOp : public OpRewritePattern<AtenSizeOp> {
 public:
@@ -266,22 +46,6 @@ public:
 };
 } // namespace
 
-namespace {
-class FoldPrimUncheckedCastOp : public OpRewritePattern<PrimUncheckedCastOp> {
-public:
-  using OpRewritePattern::OpRewritePattern;
-  LogicalResult matchAndRewrite(PrimUncheckedCastOp op,
-                                PatternRewriter &rewriter) const override {
-    if (!isValidSubtype(op.getX().getType(), op.getResult().getType())) {
-      return rewriter.notifyMatchFailure(
-          op, "input tensor type is not a valid subtype of result type");
-    }
-    rewriter.replaceOp(op, op.getX());
-    return success();
-  }
-};
-} // namespace
-
 static LogicalResult refineShapeCalculateResult(ShapeCalculateOp op,
                                                 int resultNum,
                                                 PatternRewriter &rewriter) {
@@ -367,11 +131,11 @@ class SimplifyShapeCalculationsPass
     MLIRContext *context = &getContext();
 
     RewritePatternSet patterns(context);
-    patterns.insert<FullyUnrollPrimLoopOp>(context);
-    patterns.insert<AbstractlyInterpretListOpsWithinABlock>(context);
+    populateFullyUnrollPrimLoopOpPattern(patterns, context);
+    populateAbstractlyInterpretListOpsWithinABlockPattern(patterns, context);
+    populateFoldPrimUncheckedCastOpPattern(patterns, context);
     patterns.insert<DecomposeAtenSizeOp>(context);
     patterns.insert<RefineShapeCalculateOp>(context);
-    patterns.insert<FoldPrimUncheckedCastOp>(context);
 
     PrimIfOp::getCanonicalizationPatterns(patterns, context);
     Aten__Getitem__TOp::getCanonicalizationPatterns(patterns, context);
diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/abstract_interp_lib_gen.py b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/abstract_interp_lib_gen.py
index 971fbfc8d..7ad7207f7 100644
--- a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/abstract_interp_lib_gen.py
+++ b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/abstract_interp_lib_gen.py
@@ -12,7 +12,11 @@ from torch import device
 import torch.jit._shape_functions as upstream_shape_functions
 
 from .testing_framework import Invocation, ErrorInvocation, TensorOfShape, LongTensorOfShape, NonZeroDTensorWithDtype, ZeroDTensorWithDtype, check_shape_function, check_dtype_function
-from .library_generator import generate_library, not_present_in_registry, promote_dtypes, get_dtype_of_scalar
+from .library_generator import generate_library, not_present_in_registry, promote_dtypes, get_dtype_of_scalar, is_integer_dtype, is_float_dtype, is_complex_dtype, get_priority_of_dtype, all_integer_dtypes, 
all_float_dtypes, all_complex_dtypes + +# ============================================================================== +# Shape Functions +# ============================================================================== # TODO: upstream this def _embedding_bag_helper(weight: List[int], indices: List[int], offsets: List[int], include_last_offset: bool, mode: int): @@ -79,27 +83,6 @@ def aten〇exp〡shape(self: List[int]) -> List[int]: def aten〇expm1〡shape(self: List[int]) -> List[int]: return upstream_shape_functions.unary(self) -@check_dtype_function([ - Invocation(NonZeroDTensorWithDtype(torch.float32)), - Invocation(NonZeroDTensorWithDtype(torch.float64)), - Invocation(NonZeroDTensorWithDtype(torch.bfloat16)), - Invocation(NonZeroDTensorWithDtype(torch.int64)), - Invocation(NonZeroDTensorWithDtype(torch.int32)), - Invocation(NonZeroDTensorWithDtype(torch.bool)), - Invocation(ZeroDTensorWithDtype(torch.float32)), - Invocation(ZeroDTensorWithDtype(torch.float64)), - Invocation(ZeroDTensorWithDtype(torch.bfloat16)), - Invocation(ZeroDTensorWithDtype(torch.int64)), - Invocation(ZeroDTensorWithDtype(torch.int32)), - Invocation(ZeroDTensorWithDtype(torch.bool)), -]) -def aten〇expm1〡dtype(self_rank_dtype: Tuple[int, int]) -> int: - self_rank, self_dtype = self_rank_dtype - if self_dtype == torch.float64 or self_dtype == torch.bfloat16 or self_dtype == torch.float16: - return self_dtype - else: - return torch.float32 - def aten〇sin〡shape(self: List[int]) -> List[int]: return upstream_shape_functions.unary(self) @@ -277,18 +260,6 @@ def aten〇pow〇Tensor_Tensor〡shape(self: List[int], exponent: List[int]) -> def aten〇rsub〇Scalar〡shape(self: List[int], other: float, alpha: float = 1) -> List[int]: return upstream_shape_functions.unary(self) -@check_dtype_function([ - Invocation(NonZeroDTensorWithDtype(torch.float32), other=0), - Invocation(NonZeroDTensorWithDtype(torch.int64), other=0.0), - Invocation(NonZeroDTensorWithDtype(torch.float16), other=0.0), - Invocation(ZeroDTensorWithDtype(torch.float32), other=0), - Invocation(ZeroDTensorWithDtype(torch.int64), other=0.0), - Invocation(ZeroDTensorWithDtype(torch.float16), other=0.0) -]) -def aten〇rsub〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], other: Union[int, float], alpha: Union[int, float] = 1) -> int: - self_rank, self_dtype = self_rank_dtype - return promote_dtypes([self_rank, None], [self_dtype, get_dtype_of_scalar(other)]) - def aten〇leaky_relu〡shape(self: List[int], negative_slope: float = 0.01) -> List[int]: return upstream_shape_functions.unary(self) @@ -684,19 +655,6 @@ def aten〇div〇Tensor_mode〡shape(self: List[int], other: List[int], rounding def aten〇floor_divide〡shape(self: List[int], other: List[int]) -> List[int]: return upstream_shape_functions.broadcast(self, other) -@check_dtype_function([ - Invocation(NonZeroDTensorWithDtype(torch.float32), NonZeroDTensorWithDtype(torch.float32)), - Invocation(ZeroDTensorWithDtype(torch.float64), NonZeroDTensorWithDtype(torch.float32)), - Invocation(ZeroDTensorWithDtype(torch.float32), NonZeroDTensorWithDtype(torch.float64)), - Invocation(NonZeroDTensorWithDtype(torch.float32), NonZeroDTensorWithDtype(torch.int32)), -]) -def aten〇floor_divide〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int]) -> int: - self_rank, self_dtype = self_rank_dtype - other_rank, other_dtype = other_rank_dtype - ranks: List[Optional[int]] = [self_rank, other_rank] - dtypes = [self_dtype, other_dtype] - return promote_dtypes(ranks, dtypes) - def aten〇atan2〡shape(self: List[int], other: List[int]) 
-> List[int]: return upstream_shape_functions.broadcast(self, other) @@ -885,40 +843,6 @@ def aten〇_convolution〡shape(input: List[int], weight: List[int], bias: Optio def aten〇_convolution〇deprecated〡shape(input: List[int], weight: List[int], bias: Optional[List[int]], stride: List[int], padding: List[int], dilation: List[int], transposed: bool, output_padding: List[int], groups: int, benchmark: bool, deterministic: bool, cudnn_enabled: bool) -> List[int]: return aten〇convolution〡shape(input, weight, bias, stride, padding, dilation, transposed, output_padding, groups) -_convolution_deprecated_kwargs = { - "stride" : [1, 1], "padding" : [0, 0], "dilation" : [1, 1], "transposed" : False, "output_padding" : [0, 0], - "groups" : 1, "benchmark" : False, "deterministic" : False, "cudnn_enabled" : False} -@check_dtype_function( - [Invocation(TensorOfShape(1, 1, 1, 1, dtype=torch.float32), TensorOfShape(1, 1, 1, 1, dtype=torch.float32), # Same type - TensorOfShape(1, dtype=torch.float32), **_convolution_deprecated_kwargs), - ErrorInvocation(TensorOfShape(1, 1, 1, 1, dtype=torch.int32), TensorOfShape(1, 1, 1, 1, dtype=torch.float32), # Different type - TensorOfShape(1, dtype=torch.float32), **_convolution_deprecated_kwargs), - ErrorInvocation(TensorOfShape(1, 1, 1, 1, dtype=torch.bfloat16), TensorOfShape(1, 1, 1, 1, dtype=torch.float32), # Different width - TensorOfShape(1, dtype=torch.float32), **_convolution_deprecated_kwargs), - ErrorInvocation(TensorOfShape(1, 1, 1, 1, dtype=torch.bfloat16), TensorOfShape(1, 1, 1, 1, dtype=torch.int32), # Different type and width - TensorOfShape(1, dtype=torch.float32), **_convolution_deprecated_kwargs), - ErrorInvocation(TensorOfShape(1, 1, 1, 1, dtype=torch.complex64), TensorOfShape(1, 1, 1, 1, dtype=torch.float32), - TensorOfShape(1, dtype=torch.float32), **_convolution_deprecated_kwargs), - ErrorInvocation(TensorOfShape(1, 1, 1, 1, dtype=torch.float32), TensorOfShape(1, 1, 1, 1, dtype=torch.complex128), - TensorOfShape(1, dtype=torch.float32), **_convolution_deprecated_kwargs), - ErrorInvocation(TensorOfShape(1, 1, 1, 1, dtype=torch.bool), TensorOfShape(1, 1, 1, 1, dtype=torch.float32), - TensorOfShape(1, dtype=torch.float32), **_convolution_deprecated_kwargs), - ErrorInvocation(TensorOfShape(1, 1, 1, 1, dtype=torch.float32), TensorOfShape(1, 1, 1, 1, dtype=torch.bool), - TensorOfShape(1, dtype=torch.float32), **_convolution_deprecated_kwargs), - ErrorInvocation(TensorOfShape(1, 1, 1, 1, dtype=torch.float16), TensorOfShape(1, 1, 1, 1, dtype=torch.float32), - TensorOfShape(1, dtype=torch.float32), **_convolution_deprecated_kwargs), - ErrorInvocation(TensorOfShape(1, 1, 1, 1, dtype=torch.float32), TensorOfShape(1, 1, 1, 1, dtype=torch.float16), - TensorOfShape(1, dtype=torch.float32), **_convolution_deprecated_kwargs) -]) -def aten〇_convolution〇deprecated〡dtype(input_rank_dtype: Tuple[int, int], weight_rank_dtype: Tuple[int, int], bias_rank_dtype: Optional[Tuple[int, int]], stride: List[int], padding: List[int], dilation: List[int], transposed: bool, output_padding: List[int], groups: int, benchmark: bool, deterministic: bool, cudnn_enabled: bool) -> int: - input_rank, input_dtype = input_rank_dtype - weight_rank, weight_dtype = weight_rank_dtype - assert input_dtype == weight_dtype - assert input_dtype not in [torch.bool, torch.float16, torch.complex64, torch.complex128] - ranks: List[Optional[int]] = [input_rank, weight_rank] - dtypes = [input_dtype, weight_dtype] - return promote_dtypes(ranks, dtypes) - def aten〇flip〡shape(self: List[int], dims: List[int]) 
-> List[int]: return self @@ -998,12 +922,15 @@ def aten〇cross_entropy_loss〡shape(self: List[int], target: List[int], weight def aten〇native_layer_norm〡shape(input: List[int], normalized_shape: List[int], weight: Optional[List[int]], bias: Optional[List[int]], eps: float) -> Tuple[List[int], List[int], List[int]]: return upstream_shape_functions.native_layer_norm(input, normalized_shape) +# Use CPU because META device results in the wrong behavior +# https://github.com/pytorch/pytorch/issues/100985 +# TODO: This should be fixed by switching to FakeTensor instead of Meta tensor @check_shape_function([ Invocation(TensorOfShape(2, 3), None, None, None, None, True, 1e-4, 1e-6), # Training basic case. - Invocation(TensorOfShape(2, 3), None, None, TensorOfShape(3), TensorOfShape(3), False, 1e-4, 1e-6), # Inference basic case. + Invocation(TensorOfShape(2, 3, device="cpu"), None, None, TensorOfShape(3, device="cpu"), TensorOfShape(3, device="cpu"), False, 1e-4, 1e-6), # Inference basic case. Invocation(TensorOfShape(2, 3, 4, 5, 6), None, None, None, None, True, 1e-4, 1e-6), # Training high-D case. - Invocation(TensorOfShape(2, 3, 4, 5, 6), None, None, TensorOfShape(3), TensorOfShape(3), False, 1e-4, 1e-6), # Inference high-D case. - ErrorInvocation(TensorOfShape(2), None, None, None, None, True, 1e-4, 1e-6) # Dimensionality too low. + Invocation(TensorOfShape(2, 3, 4, 5, 6, device="cpu"), None, None, TensorOfShape(3, device="cpu"), TensorOfShape(3, device="cpu"), False, 1e-4, 1e-6), # Inference high-D case. + ErrorInvocation(TensorOfShape(2, device="cpu"), None, None, None, None, True, 1e-4, 1e-6) # Dimensionality too low. ]) def aten〇native_batch_norm〡shape(input: List[int], weight: Optional[List[int]], bias: Optional[List[int]], running_mean: Optional[List[int]], running_var: Optional[List[int]], training: bool, momentum: float, eps: float) -> Tuple[List[int], List[int], List[int]]: return upstream_shape_functions.native_batch_norm(input, weight, bias, running_mean, running_var, training) @@ -1102,35 +1029,6 @@ def aten〇stack〡shape(tensors: List[List[int]], dim: int = 0) -> List[int]: def aten〇fft_fft〡shape(self: List[int], n: Optional[int] = None, dim: int = -1, norm: Optional[str] = None) -> List[int]: return self -@check_dtype_function([ - Invocation(NonZeroDTensorWithDtype(torch.complex64)), - Invocation(NonZeroDTensorWithDtype(torch.complex128)), - Invocation(NonZeroDTensorWithDtype(torch.float)), - Invocation(NonZeroDTensorWithDtype(torch.double)), - Invocation(NonZeroDTensorWithDtype(torch.bool)), - Invocation(NonZeroDTensorWithDtype(torch.uint8)), - Invocation(NonZeroDTensorWithDtype(torch.int8)), - Invocation(NonZeroDTensorWithDtype(torch.int16)), - Invocation(NonZeroDTensorWithDtype(torch.int32)), - Invocation(NonZeroDTensorWithDtype(torch.int64)), - ErrorInvocation(NonZeroDTensorWithDtype(torch.float16)), - ErrorInvocation(NonZeroDTensorWithDtype(torch.bfloat16)), -]) -def aten〇fft_fft〡dtype(self_rank_dtype: Tuple[int, int], n: Optional[int] = None, dim: int = -1, norm: Optional[str] = None) -> int: - self_rank, self_dtype = self_rank_dtype - if self_dtype == torch.complex64 or self_dtype == torch.complex128: - return self_dtype - elif self_dtype == torch.float: - return torch.complex64 - elif self_dtype == torch.double: - return torch.complex128 - elif self_dtype == torch.bool or self_dtype == torch.uint8 or \ - self_dtype == torch.int8 or self_dtype == torch.int16 or \ - self_dtype == torch.int32 or self_dtype == torch.int64: - return torch.complex64 - else: - assert False, 
"Unsupported dtype" - class DummyClassType: def __init__(self): pass @@ -1185,6 +1083,855 @@ def aten〇norm〇ScalarOpt_dim〡shape(self: List[int], p: Optional[float], dim def aten〇upsample_nearest2d〡shape(self: List[int], output_size: List[int], scales_h: Optional[float] = None, scales_w: Optional[float] = None) -> List[int]: return [self[0], self[1], output_size[0], output_size[1]] +# ============================================================================== +# Dtype Functions +# ============================================================================== + +# All the torch types sorted in decreasing order of priority during type promotion. +_SORTED_TORCH_TYPES = [ + torch.complex128, torch.complex64, + torch.float64, torch.float32, torch.float16, torch.bfloat16, + torch.int64, torch.int32, torch.int16, torch.int8, torch.uint8, torch.bool +] + +def _check_tensors_with_the_same_dtype( + num_of_tensors: Optional[int] = None, + tensor_shapes: Optional[list[tuple[int]]] = None, + tensor_device: Optional[torch.device] = None, + error_types: Optional[set[int]] = None, *args, **kwargs): + """Create invocations where all tensors have the same dtype. + + This function generates invocations with `num_of_tensors` tensors + that all have the same dtype. It creates an invocation for every + possible dtype. For dtypes in `error_types`, the invocations are + error invocations. + + One can also specify the shapes of the tensors. Either `num_of_tensors` + or `tensor_shapes` must be specified whenever this function is called. + + The extra *args and **kwargs arguments are passed to the invocations. + """ + invocations = [] + for type_ in _SORTED_TORCH_TYPES: + tensors = [] + if tensor_shapes is None and num_of_tensors is not None: + tensors = [NonZeroDTensorWithDtype(type_, device=tensor_device)] * num_of_tensors + elif tensor_shapes is not None and num_of_tensors is None: + for tensor_shape in tensor_shapes: + tensors.append(TensorOfShape(*tensor_shape, dtype=type_, device=tensor_device)) + else: + assert False, \ + "Either `num_of_tensors` or `tensor_shapes` must be specified" + + if error_types is not None and type_ in error_types: + invocations.append(ErrorInvocation(*tensors, *args, **kwargs)) + else: + invocations.append(Invocation(*tensors, *args, **kwargs)) + return invocations + +def _check_two_tensor_op( + tensor_shapes: Optional[list[tuple[int]]] = None, + tensor_device: Optional[torch.device] = None, + input_error_types: Optional[set[int]] = None, + output_error_types: Optional[set[int]] = None, **kwargs): + """Generate invocations for basic two-tensor dtype functions. + + This helper function is meant to be used to check dtype functions that + take two tensor operands and either return the promoted result or + return a constant dtype based on the tensor dtypes. + + The testing performed is thorough enough to be able to detect if dtypes + are invalid as inputs or as outputs to the PyTorch op. Invalid dtypes + must be specified in `input_error_types` and `output_error_types` to + ensure the invocations are error invocations. 
+ """ + if tensor_shapes is None: + tensor_shapes = [(1,), (1,)] + shape_1, shape_2 = tensor_shapes + + if input_error_types is not None and output_error_types is not None: + assert len(input_error_types.intersection(output_error_types)) == 0, \ + "An invalid input type implies an invalid output type, " \ + "so there is no need to repeat the type in the `output_error_types` set" + all_error_types = set() + all_error_types |= set() if input_error_types is None else input_error_types + all_error_types |= set() if output_error_types is None else output_error_types + + def check_two_tensors_with_one_varying_dtype_at_a_time(**kwargs): + """Create invocations where one tensor varies its dtype. + + This helper function creates invocations with two tensors where one + tensor varies its dtype while the other one stays constant. The varying + is done for both tensors and the varying is performed over every possible + dtype. + + This function helps identify when a dtype is an invalid input dtype + for dtype functions that do promotion. + """ + # We will only create invocations for dtypes with priorities less than + # or equal to the highest priority valid type. By setting the non-varying + # tensor dtype to be the highest priority valid type, we ensure that + # every promotion results in a valid dtype. This allows the invocations + # to test in isolation assertions on input types. + constant_type = None + constant_type_index = None + for e, type_ in enumerate(_SORTED_TORCH_TYPES): + if type_ not in all_error_types: + constant_type = type_ + constant_type_index = e + break + assert constant_type is not None, \ + "Unable to find a constant type. Make sure the union of " \ + "`input_error_types` and `output_error_types` is not all possible types." + + invocations = [] + for type_ in _SORTED_TORCH_TYPES[constant_type_index:]: + if input_error_types is not None and type_ in input_error_types: + invocation_type = ErrorInvocation + else: + invocation_type = Invocation + invocations += [invocation_type(TensorOfShape(*shape_1, dtype=type_, device=tensor_device), TensorOfShape(*shape_2, dtype=constant_type, device=tensor_device), **kwargs), + invocation_type(TensorOfShape(*shape_1, dtype=constant_type, device=tensor_device), TensorOfShape(*shape_2, dtype=type_, device=tensor_device), **kwargs)] + return invocations + + same_dtype_invocations = _check_tensors_with_the_same_dtype( + tensor_shapes=tensor_shapes, tensor_device=tensor_device, error_types=all_error_types, **kwargs) + + varying_dtype_invocations = \ + check_two_tensors_with_one_varying_dtype_at_a_time(**kwargs) + return same_dtype_invocations + varying_dtype_invocations + +def _get_dtype_of_floating_point_op(input_dtype: int) -> int: + if (is_float_dtype(input_dtype) and input_dtype != torch.float32) \ + or is_complex_dtype(input_dtype): + return input_dtype + return torch.float32 + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) +def aten〇tanh〡dtype(self_rank_dtype: Tuple[int, int]) -> int: + self_rank, self_dtype = self_rank_dtype + return _get_dtype_of_floating_point_op(self_dtype) + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) +def aten〇exp〡dtype(self_rank_dtype: Tuple[int, int]) -> int: + self_rank, self_dtype = self_rank_dtype + return _get_dtype_of_floating_point_op(self_dtype) + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) +def aten〇expm1〡dtype(self_rank_dtype: Tuple[int, int]) -> int: + self_rank, self_dtype = self_rank_dtype + return 
_get_dtype_of_floating_point_op(self_dtype) + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) +def aten〇sin〡dtype(self_rank_dtype: Tuple[int, int]) -> int: + self_rank, self_dtype = self_rank_dtype + return _get_dtype_of_floating_point_op(self_dtype) + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) +def aten〇cos〡dtype(self_rank_dtype: Tuple[int, int]) -> int: + self_rank, self_dtype = self_rank_dtype + return _get_dtype_of_floating_point_op(self_dtype) + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) +def aten〇sigmoid〡dtype(self_rank_dtype: Tuple[int, int]) -> int: + self_rank, self_dtype = self_rank_dtype + return _get_dtype_of_floating_point_op(self_dtype) + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) +def aten〇reciprocal〡dtype(self_rank_dtype: Tuple[int, int]) -> int: + self_rank, self_dtype = self_rank_dtype + return _get_dtype_of_floating_point_op(self_dtype) + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) +def aten〇sqrt〡dtype(self_rank_dtype: Tuple[int, int]) -> int: + self_rank, self_dtype = self_rank_dtype + return _get_dtype_of_floating_point_op(self_dtype) + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) +def aten〇log〡dtype(self_rank_dtype: Tuple[int, int]) -> int: + self_rank, self_dtype = self_rank_dtype + return _get_dtype_of_floating_point_op(self_dtype) + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) +def aten〇log2〡dtype(self_rank_dtype: Tuple[int, int]) -> int: + self_rank, self_dtype = self_rank_dtype + return _get_dtype_of_floating_point_op(self_dtype) + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) +def aten〇log1p〡dtype(self_rank_dtype: Tuple[int, int]) -> int: + self_rank, self_dtype = self_rank_dtype + return _get_dtype_of_floating_point_op(self_dtype) + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) +def aten〇rsqrt〡dtype(self_rank_dtype: Tuple[int, int]) -> int: + self_rank, self_dtype = self_rank_dtype + return _get_dtype_of_floating_point_op(self_dtype) + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) +def aten〇erf〡dtype(self_rank_dtype: Tuple[int, int]) -> int: + self_rank, self_dtype = self_rank_dtype + return _get_dtype_of_floating_point_op(self_dtype) + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) +def aten〇softplus〡dtype(self_rank_dtype: Tuple[int, int], beta: Union[int, float] = 1, threshold: Union[int, float] = 20) -> int: + self_rank, self_dtype = self_rank_dtype + if is_integer_dtype(self_dtype): + return self_dtype + return _get_dtype_of_floating_point_op(self_dtype) + +@check_dtype_function(_check_tensors_with_the_same_dtype( + num_of_tensors=1, + error_types={torch.bool, torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64}, dim=[0])) +def aten〇frobenius_norm〇dim〡dtype(self_rank_dtype: Tuple[int, int], dim: List[int], keepdim: bool = False) -> int: + self_rank, self_dtype = self_rank_dtype + assert not is_integer_dtype(self_dtype) + if self_dtype == torch.complex128: + return torch.float64 + elif self_dtype == torch.complex64: + return torch.float32 + return _get_dtype_of_floating_point_op(self_dtype) + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) +def prims〇sqrt〡dtype(self_rank_dtype: Tuple[int, int]) -> int: + self_rank, self_dtype = self_rank_dtype + if 
is_integer_dtype(self_dtype): + return self_dtype + return _get_dtype_of_floating_point_op(self_dtype) + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) +def aten〇abs〡dtype(self_rank_dtype: Tuple[int, int]) -> int: + self_rank, self_dtype = self_rank_dtype + if self_dtype == torch.complex128: + return torch.float64 + elif self_dtype == torch.complex64: + return torch.float32 + return self_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(tensor_shapes=[(2, 3, 5, 7)], output_size=[2, 2])) +def aten〇adaptive_avg_pool2d〡dtype(self_rank_dtype: Tuple[int, int], output_size: List[int]) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(tensor_shapes=[(2, 3, 5, 7)], kernel_size=[2, 2])) +def aten〇avg_pool2d〡dtype(self_rank_dtype: Tuple[int, int], kernel_size: List[int], stride: List[int] = (), padding: List[int] = (0, 0), ceil_mode: bool = False, count_include_pad: bool = True, divisor_override: Optional[int] = None) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype( + tensor_shapes=[(2, 3, 5), (3,), (3,), (3,), (3,)], training=False, momentum=0.1, eps=1e-5, cudnn_enabled=True)) +def aten〇batch_norm〡dtype(input_rank_dtype: Tuple[int, int], weight_rank_dtype: Optional[Tuple[int, int]], bias_rank_dtype: Optional[Tuple[int, int]], running_mean_rank_dtype: Optional[Tuple[int, int]], running_var_rank_dtype: Optional[Tuple[int, int]], training: bool, momentum: float, eps: float, cudnn_enabled: bool) -> int: + input_rank, input_dtype = input_rank_dtype + return input_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) +def aten〇bernoulli_〇float〡dtype(self_rank_dtype: Tuple[int, int], p: float = 0.5, generator: Any = None) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) +def aten〇bernoulli〡dtype(self_rank_dtype: Tuple[int, int], generator: Any = None) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=2)) +def aten〇bernoulli〇Tensor〡dtype(self_rank_dtype: Tuple[int, int], p_rank_dtype: Tuple[int, int], generator: Any = None) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) +def aten〇bitwise_not〡dtype(self_rank_dtype: Tuple[int, int]) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, size=[2, 2])) +def aten〇broadcast_to〡dtype(self_rank_dtype: Tuple[int, int], size: List[int]) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) +def aten〇ceil〡dtype(self_rank_dtype: Tuple[int, int]) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, max=0)) +def aten〇clamp_max〡dtype(self_rank_dtype: Tuple[int, int], max: Union[int, float]) -> int: + self_rank, self_dtype = self_rank_dtype + if self_dtype == torch.bool: + return torch.int64 + return self_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, min=0)) +def aten〇clamp_min〡dtype(self_rank_dtype: Tuple[int, int], min: 
Union[int, float]) -> int: + self_rank, self_dtype = self_rank_dtype + if self_dtype == torch.bool: + return torch.int64 + return self_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, min=-1, max=1)) +def aten〇clamp〡dtype(self_rank_dtype: Tuple[int, int], min: Optional[Union[int, float]] = None, max: Optional[Union[int, float]] = None) -> int: + self_rank, self_dtype = self_rank_dtype + if self_dtype == torch.bool: + return torch.int64 + return self_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) +def aten〇clone〡dtype(self_rank_dtype: Tuple[int, int], memory_format: Optional[int] = None) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, pad=[1, 1])) +def aten〇constant_pad_nd〡dtype(self_rank_dtype: Tuple[int, int], pad: List[int], value: Union[int, float] = 0) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) +def aten〇contiguous〡dtype(self_rank_dtype: Tuple[int, int], memory_format: int = 0) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + +@check_dtype_function(_check_two_tensor_op()) +def aten〇copy〡dtype(self_rank_dtype: Tuple[int, int], src_rank_dtype: Tuple[int, int], non_blocking: bool = False) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + +# TODO: This should be fixed by switching to FakeTensor instead of Meta tensor +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, tensor_device="cpu")) +def aten〇cpu〡dtype(self_rank_dtype: Tuple[int, int]) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + +@check_dtype_function( + _check_tensors_with_the_same_dtype(num_of_tensors=1, dim=0) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, dim=0, dtype=torch.float32)) +def aten〇cumsum〡dtype(self_rank_dtype: Tuple[int, int], dim: int, dtype: Optional[int] = None) -> int: + if dtype is not None: + return dtype + self_rank, self_dtype = self_rank_dtype + if is_integer_dtype(self_dtype): + return torch.int64 + return self_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) +def aten〇detach〡dtype(self_rank_dtype: Tuple[int, int]) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, p=0.5, train=False)) +def aten〇dropout〡dtype(input_rank_dtype: Tuple[int, int], p: float, train: bool) -> int: + input_rank, input_dtype = input_rank_dtype + return input_dtype + +@check_dtype_function(_check_two_tensor_op()) +def aten〇expand_as〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int]) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, size=[2, 2])) +def aten〇expand〡dtype(self_rank_dtype: Tuple[int, int], size: List[int], implicit: bool = False) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, value=0)) +def aten〇fill〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], value: Union[int, float]) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(tensor_shapes=[(1,), ()])) +def aten〇fill〇Tensor〡dtype(self_rank_dtype: Tuple[int, int], 
value_rank_dtype: Tuple[int, int]) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) +def aten〇flatten〇using_ints〡dtype(self_rank_dtype: Tuple[int, int], start_dim: int = 0, end_dim: int = -1) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, dims=[0])) +def aten〇flip〡dtype(self_rank_dtype: Tuple[int, int], dims: List[int]) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) +def aten〇floor〡dtype(self_rank_dtype: Tuple[int, int]) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(None, [(5,)], None, None, 0, TensorOfShape(1, dtype=torch.int64))) +def aten〇gather〡dtype(self_rank_dtype: Tuple[int, int], dim: int, index_rank_dtype: Tuple[int, int], sparse_grad: bool = False) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + +@check_dtype_function(_check_two_tensor_op()) +def aten〇gelu_backward〡dtype(grad_output_rank_dtype: Tuple[int, int], self_rank_dtype: Tuple[int, int], approximate: str = "none") -> int: + grad_output_rank, grad_output_dtype = grad_output_rank_dtype + self_rank, self_dtype = self_rank_dtype + ranks: List[Optional[int]] = [grad_output_rank, self_rank] + dtypes = [grad_output_dtype, self_dtype] + promoted_dtype = promote_dtypes(ranks, dtypes) + return promoted_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) +def aten〇gelu〡dtype(self_rank_dtype: Tuple[int, int], approximate: str = "none") -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) +def aten〇hardsigmoid〡dtype(self_rank_dtype: Tuple[int, int]) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) +def aten〇hardswish〡dtype(self_rank_dtype: Tuple[int, int]) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + +@check_dtype_function(_check_two_tensor_op(min_val=0.2, max_val=0.5)) +def aten〇hardtanh_backward〡dtype(grad_output_rank_dtype: Tuple[int, int], self_rank_dtype: Tuple[int, int], min_val: Union[int, float], max_val: Union[int, float]) -> int: + grad_output_rank, grad_output_dtype = grad_output_rank_dtype + if is_integer_dtype(grad_output_dtype): + return torch.float32 + return grad_output_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, error_types={torch.uint8, torch.bool})) +def aten〇hardtanh〡dtype(self_rank_dtype: Tuple[int, int], min_val: Union[int, float] = -1, max_val: Union[int, float] = 1) -> int: + self_rank, self_dtype = self_rank_dtype + assert self_dtype not in [torch.uint8, torch.bool] + return self_dtype + +_index_put_invocations = [ + # same dtype + Invocation(TensorOfShape(3, dtype=dtype), [TensorOfShape(3, dtype=torch.int64)], TensorOfShape(3, dtype=dtype)) for dtype in _SORTED_TORCH_TYPES +] + [ + # different dtypes + Invocation(TensorOfShape(3, dtype=dtype), [TensorOfShape(3, dtype=torch.int64)], TensorOfShape(3, dtype=torch.float32)) for dtype in _SORTED_TORCH_TYPES +] + [ + # index dtype + Invocation(TensorOfShape(3, dtype=torch.float32), [TensorOfShape(3, dtype=dtype)], TensorOfShape(3, dtype=torch.float32)) for dtype in 
_SORTED_TORCH_TYPES +] +@check_dtype_function(_index_put_invocations) +def aten〇index_put〇hacked_twin〡dtype(self_rank_dtype: Tuple[int, int], indices_rank_dtype: List[Tuple[int, int]], values_rank_dtype: Tuple[int, int], accumulate: bool = False) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + +@check_dtype_function(_index_put_invocations) +def aten〇_index_put_impl〡dtype(self_rank_dtype: Tuple[int, int], indices_rank_dtype: List[Optional[Tuple[int, int]]], values_rank_dtype: Tuple[int, int], accumulate: bool = False, unsafe: bool = False) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + +@check_dtype_function(_index_put_invocations) +def aten〇index_put〡dtype(self_rank_dtype: Tuple[int, int], indices_rank_dtype: List[Optional[Tuple[int, int]]], values_rank_dtype: Tuple[int, int], accumulate: bool = False) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(None, [(5,)], None, None, 0, TensorOfShape(1, dtype=torch.int64))) +def aten〇index_select〡dtype(self_rank_dtype: Tuple[int, int], dim: int, index_rank_dtype: Tuple[int, int]) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(None, [(5,)], None, None, [TensorOfShape(1, dtype=torch.int64)])) +def aten〇index〇Tensor_hacked_twin〡dtype(self_rank_dtype: Tuple[int, int], indices_rank_dtype: List[Tuple[int, int]]) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(None, [(5,)], None, None, [TensorOfShape(1, dtype=torch.int64)])) +def aten〇index〇Tensor〡dtype(self_rank_dtype: Tuple[int, int], indices_rank_dtype: List[Optional[Tuple[int, int]]]) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype( + num_of_tensors=1, error_types={*all_integer_dtypes()}, normalized_shape=[1])) +def aten〇layer_norm〡dtype(input_rank_dtype: Tuple[int, int], normalized_shape: List[int], weight_rank_dtype: Optional[Tuple[int, int]] = None, bias_rank_dtype: Optional[Tuple[int, int]] = None, eps: float = 1.0000000000000001e-05, cudnn_enable: bool = True) -> int: + input_rank, input_dtype = input_rank_dtype + assert not is_integer_dtype(input_dtype) + return input_dtype + +@check_dtype_function(_check_two_tensor_op(negative_slope=0.1, self_is_result=False)) +def aten〇leaky_relu_backward〡dtype(grad_output_rank_dtype: Tuple[int, int], self_rank_dtype: Tuple[int, int], negative_slope: Union[int, float], self_is_result: bool) -> int: + grad_output_rank, grad_output_dtype = grad_output_rank_dtype + self_rank, self_dtype = self_rank_dtype + ranks: List[Optional[int]] = [grad_output_rank, self_rank] + dtypes = [grad_output_dtype, self_dtype] + promoted_dtype = promote_dtypes(ranks, dtypes) + return promoted_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) +def aten〇lift_fresh_copy〡dtype(self_rank_dtype: Tuple[int, int]) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + +@check_dtype_function( + _check_two_tensor_op(dim=0, input_dtype=torch.float32) + + _check_two_tensor_op(dim=0, input_dtype=torch.float64)) +def aten〇_log_softmax_backward_data〡dtype(grad_output_rank_dtype: Tuple[int, int], output_rank_dtype: Tuple[int, int], dim: int, input_dtype: int) -> int: + return input_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(None, [(3,)], None, None, 
TensorOfShape(1, dtype=torch.bool), 0)) +def aten〇masked_fill〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], mask_rank_dtype: Tuple[int, int], value: Union[int, float]) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(None, [(3,)], None, None, TensorOfShape(1, dtype=torch.bool), 0)) +def aten〇masked_fill_〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], mask_rank_dtype: Tuple[int, int], value: Union[int, float]) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(None, [(3,)], None, None, TensorOfShape(1, dtype=torch.bool), TensorOfShape(dtype=torch.float32))) +def aten〇masked_fill〇Tensor〡dtype(self_rank_dtype: Tuple[int, int], mask_rank_dtype: Tuple[int, int], value_rank_dtype: Tuple[int, int]) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + +# Could not run 'aten::masked_select' with arguments from the 'Meta' backend. +# TODO: This should be fixed by switching to FakeTensor instead of Meta tensor +@check_dtype_function( + _check_tensors_with_the_same_dtype(1, None, "cpu", None, NonZeroDTensorWithDtype(torch.bool, device="cpu"))) +def aten〇masked_select〡dtype(self_rank_dtype: Tuple[int, int], mask_rank_dtype: Tuple[int, int]) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(tensor_shapes=[(2, 3, 5, 7)], kernel_size=[2, 2])) +def aten〇max_pool2d〡dtype(self_rank_dtype: Tuple[int, int], kernel_size: List[int], stride: List[int] = (), padding: List[int] = (0, 0), dilation: List[int] = (1, 1), ceil_mode: bool = False) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(tensor_shapes=[(2, 3, 5, 7)], kernel_size=[2, 2])) +def aten〇max_pool2d_with_indices〡dtype(self_rank_dtype: Tuple[int, int], kernel_size: List[int], stride: List[int] = (), padding: List[int] = (0, 0), dilation: List[int] = (1, 1), ceil_mode: bool = False) -> Tuple[int, int]: + self_rank, self_dtype = self_rank_dtype + return self_dtype, torch.int64 + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) +def aten〇mish〡dtype(self_rank_dtype: Tuple[int, int]) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, dim=0, start=0, length=1)) +def aten〇narrow〡dtype(self_rank_dtype: Tuple[int, int], dim: int, start: int, length: int) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, error_types={torch.bool})) +def aten〇neg〡dtype(self_rank_dtype: Tuple[int, int]) -> int: + self_rank, self_dtype = self_rank_dtype + assert self_dtype != torch.bool + return self_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) +def aten〇numpy_T〡dtype(self_rank_dtype: Tuple[int, int]) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, pad=[1, 1])) +def aten〇pad〡dtype(self_rank_dtype: Tuple[int, int], pad: List[int], mode: str = "constant", value: Optional[float] = None) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, dims=[0])) +def aten〇permute〡dtype(self_rank_dtype: Tuple[int, 
int], dims: List[int]) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + +@check_dtype_function(_check_two_tensor_op()) +def aten〇pow〇Tensor_Tensor〡dtype(self_rank_dtype: Tuple[int, int], exponent_rank_dtype: Tuple[int, int]) -> int: + self_rank, self_dtype = self_rank_dtype + exponent_rank, exponent_dtype = exponent_rank_dtype + ranks: List[Optional[int]] = [self_rank, exponent_rank] + dtypes = [self_dtype, exponent_dtype] + promoted_dtype = promote_dtypes(ranks, dtypes) + if promoted_dtype == torch.bool: + return torch.int64 + return promoted_dtype + +@check_dtype_function( + _check_tensors_with_the_same_dtype(num_of_tensors=2) + + [ErrorInvocation(TensorOfShape(1, dtype=torch.float32), TensorOfShape(1, dtype=torch.float64)), + ErrorInvocation(TensorOfShape(1, dtype=torch.float64), TensorOfShape(1, dtype=torch.float32))]) +def aten〇prelu〡dtype(self_rank_dtype: Tuple[int, int], weight_rank_dtype: Tuple[int, int]) -> int: + self_rank, self_dtype = self_rank_dtype + weight_rank, weight_dtype = weight_rank_dtype + assert self_dtype == weight_dtype + return self_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, error_types={torch.bool})) +def aten〇relu6〡dtype(self_rank_dtype: Tuple[int, int]) -> int: + self_rank, self_dtype = self_rank_dtype + assert self_dtype != torch.bool + return self_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) +def aten〇relu〡dtype(self_rank_dtype: Tuple[int, int]) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, repeats=[1])) +def aten〇repeat〡dtype(self_rank_dtype: Tuple[int, int], repeats: List[int]) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, size=[1], stride=[1])) +def aten〇_reshape_alias〡dtype(self_rank_dtype: Tuple[int, int], size: List[int], stride: List[int]) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, shape=[1])) +def aten〇reshape〡dtype(self_rank_dtype: Tuple[int, int], shape: List[int]) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, size=[1])) +def aten〇resize_〡dtype(self_rank_dtype: Tuple[int, int], size: List[int], memory_format: Optional[int] = None) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, shifts=[0], dims=[0])) +def aten〇roll〡dtype(self_rank_dtype: Tuple[int, int], shifts: List[int], dims: List[int] = ()) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) +def aten〇round〡dtype(self_rank_dtype: Tuple[int, int]) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + +@check_dtype_function( + [Invocation(TensorOfShape(3, dtype=dtype), 0, TensorOfShape(3, dtype=torch.int64), TensorOfShape(3, dtype=dtype), "sum") for dtype in _SORTED_TORCH_TYPES]) +def aten〇scatter_reduce〇two〡dtype(self_rank_dtype: Tuple[int, int], dim: int, index_rank_dtype: Tuple[int, int], src_rank_dtype: Tuple[int, int], reduce: str, include_self: bool = True) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + 
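+# As a rough illustration of the calling convention used throughout this
+# library (a sketch only; the ranks and dtypes below are arbitrary): dtype
+# functions receive (rank, dtype) tuples in place of real tensors, so the
+# function above can be called directly:
+#
+#     result = aten〇scatter_reduce〇two〡dtype(
+#         (1, torch.float32),  # self: rank-1 float32 tensor
+#         0,                   # dim
+#         (1, torch.int64),    # index: rank-1 int64 tensor
+#         (1, torch.float32),  # src: rank-1 float32 tensor
+#         reduce="sum")
+#     assert result == torch.float32  # scatter_reduce keeps `self`'s dtype
+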
+@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, dim=0, index=0)) +def aten〇select〇int〡dtype(self_rank_dtype: Tuple[int, int], dim: int, index: int) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + +@check_dtype_function(_check_two_tensor_op(tensor_shapes=[(1, 1), (1,)], dim=0, index=0)) +def aten〇select_scatter〡dtype(self_rank_dtype: Tuple[int, int], src_rank_dtype: Tuple[int, int], dim: int, index: int) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) +def aten〇silu〡dtype(self_rank_dtype: Tuple[int, int]) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + +@check_dtype_function(_check_two_tensor_op(dim=0)) +def aten〇slice_scatter〡dtype(self_rank_dtype: Tuple[int, int], src_rank_dtype: Tuple[int, int], dim: int = 0, start: Optional[int] = None, end: Optional[int] = None, step: int = 1) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) +def aten〇slice〇Tensor〡dtype(self_rank_dtype: Tuple[int, int], dim: int = 0, start: Optional[int] = None, end: Optional[int] = None, step: int = 1) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + +@check_dtype_function( + _check_tensors_with_the_same_dtype(num_of_tensors=2, dim=0, input_dtype=torch.float32) + + _check_tensors_with_the_same_dtype(num_of_tensors=2, dim=0, input_dtype=torch.float64) + + [Invocation(TensorOfShape(1, dtype=torch.float32), TensorOfShape(1, dtype=torch.float64), dim=0, input_dtype=torch.float32), + Invocation(TensorOfShape(1, dtype=torch.float64), TensorOfShape(1, dtype=torch.float32), dim=0, input_dtype=torch.float32)]) +def aten〇_softmax_backward_data〡dtype(grad_output_rank_dtype: Tuple[int, int], output_rank_dtype: Tuple[int, int], dim: int, input_dtype: int) -> int: + return input_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) +def aten〇square〡dtype(self_rank_dtype: Tuple[int, int]) -> int: + self_rank, self_dtype = self_rank_dtype + if self_dtype == torch.bool: + return torch.int64 + return self_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, dim=0)) +def aten〇squeeze〇dim〡dtype(self_rank_dtype: Tuple[int, int], dim: int) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) +def aten〇squeeze〡dtype(self_rank_dtype: Tuple[int, int]) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + +@check_dtype_function(_check_two_tensor_op()) +def aten〇tanh_backward〡dtype(grad_output_rank_dtype: Tuple[int, int], output_rank_dtype: Tuple[int, int]) -> int: + grad_output_rank, grad_output_dtype = grad_output_rank_dtype + output_rank, output_dtype = output_rank_dtype + ranks: List[Optional[int]] = [grad_output_rank, output_rank] + dtypes = [grad_output_dtype, output_dtype] + promoted_dtype = promote_dtypes(ranks, dtypes) + return promoted_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, threshold=0, value=0)) +def aten〇threshold〡dtype(self_rank_dtype: Tuple[int, int], threshold: Union[int, float], value: Union[int, float]) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) +def aten〇t〡dtype(self_rank_dtype: Tuple[int, int]) -> int: + 
self_rank, self_dtype = self_rank_dtype + return self_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(1, tensor_device="meta", device=torch.device("meta"))) +def aten〇to〇prim_Device〡dtype(self_rank_dtype: Tuple[int, int], device: Optional[device], dtype: Optional[int] = None, non_blocking: bool = False, copy: bool = False) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(tensor_shapes=[(2, 3)], dim0=0, dim1=1)) +def aten〇transpose〇int〡dtype(self_rank_dtype: Tuple[int, int], dim0: int, dim1: int) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(tensor_shapes=[(2, 3)])) +def aten〇triu〡dtype(self_rank_dtype: Tuple[int, int], diagonal: int = 0) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) +def aten〇uniform〡dtype(self_rank_dtype: Tuple[int, int], from_: float = 0., to: float = 1., generator: Any = None) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, size=[1])) +def aten〇_unsafe_view〡dtype(self_rank_dtype: Tuple[int, int], size: List[int]) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, dim=0)) +def aten〇unsqueeze〡dtype(self_rank_dtype: Tuple[int, int], dim: int) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(tensor_shapes=[(1, 1, 4, 8)], output_size=[4, 8], input_size=[1, 1, 2, 3])) +def aten〇upsample_nearest2d_backward〡dtype(grad_output_rank_dtype: Tuple[int, int], output_size: List[int], input_size: List[int], scales_h: Optional[float] = None, scales_w: Optional[float] = None) -> int: + grad_output_rank, grad_output_dtype = grad_output_rank_dtype + return grad_output_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(tensor_shapes=[(2, 3, 5, 7)], output_size=[11, 13])) +def aten〇upsample_nearest2d〡dtype(self_rank_dtype: Tuple[int, int], output_size: List[int], scales_h: Optional[float] = None, scales_w: Optional[float] = None) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, size=[1])) +def aten〇view〡dtype(self_rank_dtype: Tuple[int, int], size: List[int]) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) +def aten〇zero〡dtype(self_rank_dtype: Tuple[int, int]) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) +def aten〇zero_〡dtype(self_rank_dtype: Tuple[int, int]) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + +@check_dtype_function([Invocation(-1), Invocation(-1.0)]) +def prim〇abs〇Scalar〡dtype(a: Union[int, float]) -> int: + return get_dtype_of_scalar(a) + +@check_dtype_function(_check_tensors_with_the_same_dtype( + None, [(3,), (3, 4)], None, None, + TensorOfShape(3, dtype=torch.int64), None, 0, 10, TensorOfShape(1, dtype=torch.float32)) + + [Invocation(TensorOfShape(3, dtype=torch.float32), TensorOfShape(3, 4, dtype=torch.float64), TensorOfShape(3, dtype=torch.int64), None, 0, 10, TensorOfShape(1, 
dtype=torch.float32)), + Invocation(TensorOfShape(3, dtype=torch.float64), TensorOfShape(3, 4, dtype=torch.float32), TensorOfShape(3, dtype=torch.int64), None, 0, 10, TensorOfShape(1, dtype=torch.float32))]) +def aten〇nll_loss_backward〡dtype(grad_output_rank_dtype: Tuple[int, int], self_rank_dtype: Tuple[int, int], target_rank_dtype: Tuple[int, int], weight_rank_dtype: Optional[Tuple[int, int]], reduction: int, ignore_index: int, total_weight_rank_dtype: Tuple[int, int]) -> int: + grad_output_rank, grad_output_dtype = grad_output_rank_dtype + self_rank, self_dtype = self_rank_dtype + ranks: List[Optional[int]] = [self_rank, grad_output_rank] + dtypes = [self_dtype, grad_output_dtype] + result = promote_dtypes(ranks, dtypes) + if result == torch.bool: + return torch.int64 + return result + +@check_dtype_function(_check_tensors_with_the_same_dtype( + None, [(2, 4, 7, 6), (2, 4, 6, 5)], None, None, + [2, 2], [1, 1], [1, 1], [1, 1], False, TensorOfShape(2, 4, 7, 6, dtype=torch.int64)) + + [ErrorInvocation(TensorOfShape(2, 4, 7, 6, dtype=torch.float32), TensorOfShape(2, 4, 6, 5, dtype=torch.float64), [2, 2], [1, 1], [1, 1], [1, 1], False, TensorOfShape(2, 4, 7, 6, dtype=torch.int64)), + ErrorInvocation(TensorOfShape(2, 4, 7, 6, dtype=torch.float64), TensorOfShape(2, 4, 6, 5, dtype=torch.float32), [2, 2], [1, 1], [1, 1], [1, 1], False, TensorOfShape(2, 4, 7, 6, dtype=torch.int64))]) +def aten〇max_pool2d_with_indices_backward〡dtype(grad_output_rank_dtype: Tuple[int, int], self_rank_dtype: Tuple[int, int], kernel_size: List[int], stride: List[int], padding: List[int], dilation: List[int], ceil_mode: bool, indices_rank_dtype: Tuple[int, int]) -> int: + grad_output_rank, grad_output_dtype = grad_output_rank_dtype + self_rank, self_dtype = self_rank_dtype + assert grad_output_dtype == self_dtype + return self_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) +def aten〇all〡dtype(self_rank_dtype: Tuple[int, int]) -> int: + self_rank, self_dtype = self_rank_dtype + return torch.uint8 if self_dtype == torch.uint8 else torch.bool + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) +def aten〇any〡dtype(self_rank_dtype: Tuple[int, int]) -> int: + self_rank, self_dtype = self_rank_dtype + return torch.uint8 if self_dtype == torch.uint8 else torch.bool + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, other=0.0) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, other=0)) +def aten〇eq〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], other: Union[int, float]) -> int: + return torch.bool + +@check_dtype_function(_check_two_tensor_op()) +def aten〇eq〇Tensor〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int]) -> int: + return torch.bool + +@check_dtype_function( + _check_tensors_with_the_same_dtype(num_of_tensors=1, other=0.0) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, other=0)) +def aten〇ge〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], other: Union[int, float]) -> int: + return torch.bool + +@check_dtype_function( + _check_tensors_with_the_same_dtype(num_of_tensors=1, other=0.0) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, other=0)) +def aten〇gt〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], other: Union[int, float]) -> int: + return torch.bool + +@check_dtype_function(_check_two_tensor_op()) +def aten〇gt〇Tensor〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int]) -> int: + return torch.bool + +@check_dtype_function(_check_two_tensor_op()) +def 
aten〇ge〇Tensor〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int]) -> int: + return torch.bool + +@check_dtype_function( + _check_tensors_with_the_same_dtype(num_of_tensors=1, other=0.0) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, other=0)) +def aten〇le〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], other: Union[int, float]) -> int: + return torch.bool + +@check_dtype_function(_check_two_tensor_op()) +def aten〇logical_and〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int]) -> int: + return torch.bool + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) +def aten〇logical_not〡dtype(self_rank_dtype: Tuple[int, int]) -> int: + return torch.bool + +@check_dtype_function(_check_two_tensor_op()) +def aten〇logical_or〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int]) -> int: + return torch.bool + +@check_dtype_function(_check_two_tensor_op()) +def aten〇logical_xor〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int]) -> int: + return torch.bool + +@check_dtype_function( + _check_tensors_with_the_same_dtype(num_of_tensors=1, other=0.0) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, other=0)) +def aten〇lt〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], other: Union[int, float]) -> int: + return torch.bool + +@check_dtype_function(_check_two_tensor_op()) +def aten〇lt〇Tensor〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int]) -> int: + return torch.bool + +@check_dtype_function(_check_two_tensor_op()) +def aten〇le〇Tensor〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int]) -> int: + return torch.bool + +@check_dtype_function( + _check_tensors_with_the_same_dtype(num_of_tensors=1, other=0.0) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, other=0)) +def aten〇ne〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], other: Union[int, float]) -> int: + return torch.bool + @check_dtype_function([ Invocation(0.0, 0.0), # float, float Invocation(0.0, 0), # float, int @@ -1196,6 +1943,1274 @@ def aten〇add〡dtype(a: Union[int, float], b: Union[int, float]) -> int: dtypes = [get_dtype_of_scalar(a), get_dtype_of_scalar(b)] return promote_dtypes(ranks, dtypes) +@check_dtype_function( + _check_tensors_with_the_same_dtype(num_of_tensors=1, error_types={torch.bfloat16})) +def aten〇fft_fft〡dtype(self_rank_dtype: Tuple[int, int], n: Optional[int] = None, dim: int = -1, norm: Optional[str] = None) -> int: + self_rank, self_dtype = self_rank_dtype + if is_complex_dtype(self_dtype): + return self_dtype + elif self_dtype == torch.float16: + return torch.complex32 + elif self_dtype == torch.float32: + return torch.complex64 + elif self_dtype == torch.float64: + return torch.complex128 + elif is_integer_dtype(self_dtype): + return torch.complex64 + else: + assert False, "Unsupported dtype" + +@check_dtype_function( + _check_tensors_with_the_same_dtype(num_of_tensors=1, other=0.0) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, other=0)) +def aten〇rsub〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], other: Union[int, float], alpha: Union[int, float] = 1) -> int: + self_rank, self_dtype = self_rank_dtype + return promote_dtypes([self_rank, None], [self_dtype, get_dtype_of_scalar(other)]) + +@check_dtype_function(_check_two_tensor_op()) +def aten〇__and__〇Tensor〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int]) -> int: + other_rank, other_dtype = other_rank_dtype + self_rank, self_dtype = self_rank_dtype + ranks: 
List[Optional[int]] = [self_rank, other_rank] + dtypes = [self_dtype, other_dtype] + return promote_dtypes(ranks, dtypes) + +@check_dtype_function(_check_two_tensor_op()) +def aten〇add〇Tensor〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int], alpha: Union[int, float] = 1) -> int: + other_rank, other_dtype = other_rank_dtype + self_rank, self_dtype = self_rank_dtype + ranks: List[Optional[int]] = [self_rank, other_rank] + dtypes = [self_dtype, other_dtype] + return promote_dtypes(ranks, dtypes) + +@check_dtype_function(_check_two_tensor_op()) +def aten〇bitwise_and〇Tensor〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int]) -> int: + other_rank, other_dtype = other_rank_dtype + self_rank, self_dtype = self_rank_dtype + ranks: List[Optional[int]] = [self_rank, other_rank] + dtypes = [self_dtype, other_dtype] + return promote_dtypes(ranks, dtypes) + +@check_dtype_function(_check_two_tensor_op()) +def aten〇bitwise_or〇Tensor〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int]) -> int: + other_rank, other_dtype = other_rank_dtype + self_rank, self_dtype = self_rank_dtype + ranks: List[Optional[int]] = [self_rank, other_rank] + dtypes = [self_dtype, other_dtype] + return promote_dtypes(ranks, dtypes) + +@check_dtype_function(_check_two_tensor_op()) +def aten〇bitwise_xor〇Tensor〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int]) -> int: + other_rank, other_dtype = other_rank_dtype + self_rank, self_dtype = self_rank_dtype + ranks: List[Optional[int]] = [self_rank, other_rank] + dtypes = [self_dtype, other_dtype] + return promote_dtypes(ranks, dtypes) + +@check_dtype_function( + _check_tensors_with_the_same_dtype(tensor_shapes=[(2, 3, 4), (2, 4, 3)]) + + # Different width + [Invocation(TensorOfShape(2, 3, 4, dtype=torch.float64), + TensorOfShape(2, 4, 3, dtype=torch.float32)), + # Two f16 types + Invocation(TensorOfShape(2, 3, 4, dtype=torch.float16), + TensorOfShape(2, 4, 3, dtype=torch.bfloat16)), + # Different type + Invocation(TensorOfShape(2, 3, 4, dtype=torch.float32), + TensorOfShape(2, 4, 3, dtype=torch.int32))]) +def aten〇bmm〡dtype(self_rank_dtype: Tuple[int, int], mat2_rank_dtype: Tuple[int, int]) -> int: + mat2_rank, mat2_dtype = mat2_rank_dtype + self_rank, self_dtype = self_rank_dtype + mat2_priority = get_priority_of_dtype(mat2_dtype) + self_priority = get_priority_of_dtype(self_dtype) + return mat2_dtype if mat2_priority < self_priority else self_dtype + +@check_dtype_function(_check_two_tensor_op()) +def aten〇div〇Tensor〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int]) -> int: + other_rank, other_dtype = other_rank_dtype + self_rank, self_dtype = self_rank_dtype + ranks: List[Optional[int]] = [self_rank, other_rank] + dtypes = [self_dtype, other_dtype] + promoted_dtype = promote_dtypes(ranks, dtypes) + if is_complex_dtype(promoted_dtype) or \ + (is_float_dtype(promoted_dtype) and promoted_dtype != torch.float32): + return promoted_dtype + else: + return torch.float32 + +@check_dtype_function(_check_two_tensor_op(rounding_mode=None)) +def aten〇div〇Tensor_mode〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int], rounding_mode: Optional[str]) -> int: + other_rank, other_dtype = other_rank_dtype + self_rank, self_dtype = self_rank_dtype + ranks: List[Optional[int]] = [self_rank, other_rank] + dtypes = [self_dtype, other_dtype] + promoted_dtype = promote_dtypes(ranks, dtypes) + if is_complex_dtype(promoted_dtype) or \ + (is_float_dtype(promoted_dtype) 
and promoted_dtype != torch.float32): + return promoted_dtype + else: + return torch.float32 + +@check_dtype_function(_check_two_tensor_op(input_error_types={torch.complex64, torch.complex128}, output_error_types={torch.bool})) +def aten〇floor_divide〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int]) -> int: + other_rank, other_dtype = other_rank_dtype + self_rank, self_dtype = self_rank_dtype + assert not is_complex_dtype(self_dtype), "`self` cannot be complex" + assert not is_complex_dtype(other_dtype), "`other` cannot be complex" + ranks: List[Optional[int]] = [self_rank, other_rank] + dtypes = [self_dtype, other_dtype] + promoted_dtype = promote_dtypes(ranks, dtypes) + assert promoted_dtype != torch.bool, "Result dtype for aten.floor_divide cannot be bool" + return promoted_dtype + +@check_dtype_function( + _check_tensors_with_the_same_dtype(tensor_shapes=[(2, 3, 4), (2, 4, 3)]) + + # Different width + [Invocation(TensorOfShape(2, 3, 4, dtype=torch.float64), + TensorOfShape(2, 4, 3, dtype=torch.float32)), + # Two f16 types + Invocation(TensorOfShape(2, 3, 4, dtype=torch.float16), + TensorOfShape(2, 4, 3, dtype=torch.bfloat16)), + # Different type + Invocation(TensorOfShape(2, 3, 4, dtype=torch.float32), + TensorOfShape(2, 4, 3, dtype=torch.int32))]) +def aten〇matmul〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int]) -> int: + other_rank, other_dtype = other_rank_dtype + self_rank, self_dtype = self_rank_dtype + other_priority = get_priority_of_dtype(other_dtype) + self_priority = get_priority_of_dtype(self_dtype) + return other_dtype if other_priority < self_priority else self_dtype + +@check_dtype_function(_check_two_tensor_op()) +def aten〇maximum〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int]) -> int: + other_rank, other_dtype = other_rank_dtype + self_rank, self_dtype = self_rank_dtype + ranks: List[Optional[int]] = [self_rank, other_rank] + dtypes = [self_dtype, other_dtype] + return promote_dtypes(ranks, dtypes) + +@check_dtype_function(_check_two_tensor_op()) +def aten〇minimum〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int]) -> int: + other_rank, other_dtype = other_rank_dtype + self_rank, self_dtype = self_rank_dtype + ranks: List[Optional[int]] = [self_rank, other_rank] + dtypes = [self_dtype, other_dtype] + return promote_dtypes(ranks, dtypes) + +@check_dtype_function( + _check_tensors_with_the_same_dtype(tensor_shapes=[(3, 4), (4, 3)]) + + # Different width + [Invocation(TensorOfShape(3, 4, dtype=torch.float64), + TensorOfShape(4, 3, dtype=torch.float32)), + # Two f16 types + Invocation(TensorOfShape(3, 4, dtype=torch.float16), + TensorOfShape(4, 3, dtype=torch.bfloat16)), + # Different type + Invocation(TensorOfShape(3, 4, dtype=torch.float32), + TensorOfShape(4, 3, dtype=torch.int32))]) +def aten〇mm〡dtype(self_rank_dtype: Tuple[int, int], mat2_rank_dtype: Tuple[int, int]) -> int: + mat2_rank, mat2_dtype = mat2_rank_dtype + self_rank, self_dtype = self_rank_dtype + + float16_types = [torch.bfloat16, torch.float16] + if self_dtype in float16_types and mat2_dtype in float16_types and self_dtype != mat2_dtype: + return torch.float16 + + ranks: List[Optional[int]] = [self_rank, mat2_rank] + dtypes = [self_dtype, mat2_dtype] + return promote_dtypes(ranks, dtypes) + +@check_dtype_function(_check_two_tensor_op( + output_error_types={torch.bool, torch.int8, torch.uint8, torch.int16, torch.int32, torch.int64})) +def aten〇mse_loss〡dtype(self_rank_dtype: Tuple[int, int], target_rank_dtype: 
Tuple[int, int], reduction: int = 1) -> int: + self_rank, self_dtype = self_rank_dtype + target_rank, target_dtype = target_rank_dtype + ranks: List[Optional[int]] = [self_rank, target_rank] + dtypes = [self_dtype, target_dtype] + promoted_dtype = promote_dtypes(ranks, dtypes) + assert not is_integer_dtype(promoted_dtype) + return promoted_dtype + +@check_dtype_function(_check_two_tensor_op()) +def aten〇mul〇Tensor〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int]) -> int: + other_rank, other_dtype = other_rank_dtype + self_rank, self_dtype = self_rank_dtype + ranks: List[Optional[int]] = [self_rank, other_rank] + dtypes = [self_dtype, other_dtype] + return promote_dtypes(ranks, dtypes) + +@check_dtype_function(_check_tensors_with_the_same_dtype(tensor_shapes=[(3, 4), (4,)]) + + # Different width + [Invocation(TensorOfShape(3, 4, dtype=torch.float64), + TensorOfShape(4, dtype=torch.float32)), + # Two f16 types + Invocation(TensorOfShape(3, 4, dtype=torch.float16), + TensorOfShape(4, dtype=torch.bfloat16)), + # Different type + Invocation(TensorOfShape(3, 4, dtype=torch.float32), + TensorOfShape(4, dtype=torch.int32))]) +def aten〇mv〡dtype(self_rank_dtype: Tuple[int, int], vec_rank_dtype: Tuple[int, int]) -> int: + self_rank, self_dtype = self_rank_dtype + vec_rank, vec_dtype = vec_rank_dtype + ranks: List[Optional[int]] = [self_rank, vec_rank] + dtypes = [self_dtype, vec_dtype] + return promote_dtypes(ranks, dtypes) + +@check_dtype_function(_check_two_tensor_op()) +def aten〇sub〇Tensor〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int], alpha: Union[int, float] = 1) -> int: + other_rank, other_dtype = other_rank_dtype + self_rank, self_dtype = self_rank_dtype + ranks: List[Optional[int]] = [self_rank, other_rank] + dtypes = [self_dtype, other_dtype] + return promote_dtypes(ranks, dtypes) + +# Use CPU because META device results in the wrong behavior +# https://github.com/pytorch/pytorch/issues/100921 +# TODO: This should be fixed by switching to FakeTensor instead of Meta tensor +@check_dtype_function(_check_two_tensor_op(tensor_device="cpu", input_error_types={torch.complex64, torch.complex128}, output_error_types={torch.bool, torch.float16}, threshold=0)) +def aten〇threshold_backward〡dtype(grad_output_rank_dtype: Tuple[int, int], self_rank_dtype: Tuple[int, int], threshold: Union[int, float]) -> int: + self_rank, self_dtype = self_rank_dtype + grad_output_rank, grad_output_dtype = grad_output_rank_dtype + assert not is_complex_dtype(grad_output_dtype), "`grad_output` cannot be complex" + assert not is_complex_dtype(self_dtype), "`self` cannot be complex" + ranks: List[Optional[int]] = [grad_output_rank, self_rank] + dtypes = [grad_output_dtype, self_dtype] + promoted_dtype = promote_dtypes(ranks, dtypes) + assert promoted_dtype not in [torch.bool, torch.float16], \ + "Result dtype for aten.threshold_backward cannot be bool or float16" + return promoted_dtype + +_convolution_kwargs = { + "stride" : [1, 1], "padding" : [0, 0], "dilation" : [1, 1], "transposed" : False, "output_padding" : [0, 0], + "groups" : 1, "benchmark" : False, "deterministic" : False, "cudnn_enabled" : False, "allow_tf32" : False} +# This op fails when using meta backend with error: +# Op raised error 'convolution_overrideable not implemented. +# You are likely triggering this with tensor backend other than +# CPU/CUDA/MKLDNN, if this is intended, please use TORCH_LIBRARY_IMPL +# to override this function ' but dtype function did not raise any error. 
+# +# This is similar to https://github.com/pytorch/pytorch/issues/97481 +# TODO: This should be fixed by switching to FakeTensor instead of Meta tensor +@check_dtype_function( + _check_tensors_with_the_same_dtype( + tensor_shapes=[(1, 1, 1, 1), (1, 1, 1, 1), (1,)], + tensor_device="cpu", + error_types={torch.bool, torch.float16, torch.complex64, torch.complex128}, **_convolution_kwargs) + + [ErrorInvocation(TensorOfShape(1, 1, 1, 1, dtype=torch.bool, device="cpu"), TensorOfShape(1, 1, 1, 1, dtype=torch.float32, device="cpu"), + TensorOfShape(1, dtype=torch.float32, device="cpu"), **_convolution_kwargs), + ErrorInvocation(TensorOfShape(1, 1, 1, 1, dtype=torch.float32, device="cpu"), TensorOfShape(1, 1, 1, 1, dtype=torch.bool, device="cpu"), + TensorOfShape(1, dtype=torch.float32, device="cpu"), **_convolution_kwargs), + ErrorInvocation(TensorOfShape(1, 1, 1, 1, dtype=torch.float16, device="cpu"), TensorOfShape(1, 1, 1, 1, dtype=torch.float32, device="cpu"), + TensorOfShape(1, dtype=torch.float32, device="cpu"), **_convolution_kwargs), + ErrorInvocation(TensorOfShape(1, 1, 1, 1, dtype=torch.float32, device="cpu"), TensorOfShape(1, 1, 1, 1, dtype=torch.float16, device="cpu"), + TensorOfShape(1, dtype=torch.float32, device="cpu"), **_convolution_kwargs)]) +def aten〇_convolution〡dtype(input_rank_dtype: Tuple[int, int], weight_rank_dtype: Tuple[int, int], bias_rank_dtype: Optional[Tuple[int, int]], stride: List[int], padding: List[int], dilation: List[int], transposed: bool, output_padding: List[int], groups: int, benchmark: bool, deterministic: bool, cudnn_enabled: bool, allow_tf32: bool) -> int: + input_rank, input_dtype = input_rank_dtype + weight_rank, weight_dtype = weight_rank_dtype + assert not is_complex_dtype(input_dtype) and input_dtype not in [torch.bool, torch.float16] + assert not is_complex_dtype(weight_dtype) and weight_dtype not in [torch.bool, torch.float16] + ranks: List[Optional[int]] = [input_rank, weight_rank] + dtypes = [input_dtype, weight_dtype] + return promote_dtypes(ranks, dtypes) + +_convolution_deprecated_kwargs = { + "stride" : [1, 1], "padding" : [0, 0], "dilation" : [1, 1], "transposed" : False, "output_padding" : [0, 0], + "groups" : 1, "benchmark" : False, "deterministic" : False, "cudnn_enabled" : False} +# This op fails when using meta backend with error: +# Op raised error 'convolution_overrideable not implemented. +# You are likely triggering this with tensor backend other than +# CPU/CUDA/MKLDNN, if this is intended, please use TORCH_LIBRARY_IMPL +# to override this function ' but dtype function did not raise any error. 
+# +# This is similar to https://github.com/pytorch/pytorch/issues/97481 +# TODO: This should be fixed by switching to FakeTensor instead of Meta tensor +@check_dtype_function( + _check_tensors_with_the_same_dtype( + tensor_shapes=[(1, 1, 1, 1), (1, 1, 1, 1), (1,)], + tensor_device="cpu", + error_types={torch.bool, torch.float16, torch.complex64, torch.complex128}, **_convolution_deprecated_kwargs) + + [ErrorInvocation(TensorOfShape(1, 1, 1, 1, dtype=torch.bool, device="cpu"), TensorOfShape(1, 1, 1, 1, dtype=torch.float32, device="cpu"), + TensorOfShape(1, dtype=torch.float32, device="cpu"), **_convolution_deprecated_kwargs), + ErrorInvocation(TensorOfShape(1, 1, 1, 1, dtype=torch.float32, device="cpu"), TensorOfShape(1, 1, 1, 1, dtype=torch.bool, device="cpu"), + TensorOfShape(1, dtype=torch.float32, device="cpu"), **_convolution_deprecated_kwargs), + ErrorInvocation(TensorOfShape(1, 1, 1, 1, dtype=torch.float16, device="cpu"), TensorOfShape(1, 1, 1, 1, dtype=torch.float32, device="cpu"), + TensorOfShape(1, dtype=torch.float32, device="cpu"), **_convolution_deprecated_kwargs), + ErrorInvocation(TensorOfShape(1, 1, 1, 1, dtype=torch.float32, device="cpu"), TensorOfShape(1, 1, 1, 1, dtype=torch.float16, device="cpu"), + TensorOfShape(1, dtype=torch.float32, device="cpu"), **_convolution_deprecated_kwargs) +]) +def aten〇_convolution〇deprecated〡dtype(input_rank_dtype: Tuple[int, int], weight_rank_dtype: Tuple[int, int], bias_rank_dtype: Optional[Tuple[int, int]], stride: List[int], padding: List[int], dilation: List[int], transposed: bool, output_padding: List[int], groups: int, benchmark: bool, deterministic: bool, cudnn_enabled: bool) -> int: + input_rank, input_dtype = input_rank_dtype + weight_rank, weight_dtype = weight_rank_dtype + assert not is_complex_dtype(input_dtype) and input_dtype not in [torch.bool, torch.float16] + assert not is_complex_dtype(weight_dtype) and weight_dtype not in [torch.bool, torch.float16] + ranks: List[Optional[int]] = [input_rank, weight_rank] + dtypes = [input_dtype, weight_dtype] + return promote_dtypes(ranks, dtypes) + +@check_dtype_function( + _check_tensors_with_the_same_dtype(tensor_shapes=[(1, 1, 1, 1), (1, 1, 1, 1)]) + + [Invocation(TensorOfShape(1, 1, 1, 1, dtype=torch.bool), TensorOfShape(1, 1, 1, 1, dtype=torch.float32)), + Invocation(TensorOfShape(1, 1, 1, 1, dtype=torch.float32), TensorOfShape(1, 1, 1, 1, dtype=torch.bool)), + Invocation(TensorOfShape(1, 1, 1, 1, dtype=torch.float16), TensorOfShape(1, 1, 1, 1, dtype=torch.float32)), + Invocation(TensorOfShape(1, 1, 1, 1, dtype=torch.float32), TensorOfShape(1, 1, 1, 1, dtype=torch.float16)) +]) +def aten〇conv2d〡dtype(input_rank_dtype: Tuple[int, int], weight_rank_dtype: Tuple[int, int], bias_rank_dtype: Optional[Tuple[int, int]] = None, stride: List[int] = (1, 1), padding: List[int] = (0, 0), dilation: List[int] = (1, 1), groups: int = 1) -> int: + input_rank, input_dtype = input_rank_dtype + return input_dtype + +@check_dtype_function( + _check_tensors_with_the_same_dtype(tensor_shapes=[(1, 1, 1, 1), (1, 1, 1, 1)]) + + [Invocation(TensorOfShape(1, 1, 1, 1, dtype=torch.bool), TensorOfShape(1, 1, 1, 1, dtype=torch.float32)), + Invocation(TensorOfShape(1, 1, 1, 1, dtype=torch.float32), TensorOfShape(1, 1, 1, 1, dtype=torch.bool)), + Invocation(TensorOfShape(1, 1, 1, 1, dtype=torch.float16), TensorOfShape(1, 1, 1, 1, dtype=torch.float32)), + Invocation(TensorOfShape(1, 1, 1, 1, dtype=torch.float32), TensorOfShape(1, 1, 1, 1, dtype=torch.float16)) +]) +def 
aten〇conv_transpose2d〇input〡dtype(input_rank_dtype: Tuple[int, int], weight_rank_dtype: Tuple[int, int], bias_rank_dtype: Optional[Tuple[int, int]] = None, stride: List[int] = (1, 1), padding: List[int] = (0, 0), output_padding: List[int] = (0, 0), groups: int = 1, dilation: List[int] = (1, 1)) -> int: + input_rank, input_dtype = input_rank_dtype + return input_dtype + +convolution_kwargs = { + "stride" : [1, 1], "padding" : [0, 0], "dilation" : [1, 1], "transposed" : False, "output_padding" : [0, 0], "groups" : 1} +@check_dtype_function( + _check_tensors_with_the_same_dtype( + tensor_shapes=[(1, 1, 1, 1), (1, 1, 1, 1), (1,)], **convolution_kwargs) + + [Invocation(TensorOfShape(1, 1, 1, 1, dtype=torch.bool), TensorOfShape(1, 1, 1, 1, dtype=torch.float32), + TensorOfShape(1, dtype=torch.float32), **convolution_kwargs), + Invocation(TensorOfShape(1, 1, 1, 1, dtype=torch.float32), TensorOfShape(1, 1, 1, 1, dtype=torch.bool), + TensorOfShape(1, dtype=torch.float32), **convolution_kwargs), + Invocation(TensorOfShape(1, 1, 1, 1, dtype=torch.float16), TensorOfShape(1, 1, 1, 1, dtype=torch.float32), + TensorOfShape(1, dtype=torch.float32), **convolution_kwargs), + Invocation(TensorOfShape(1, 1, 1, 1, dtype=torch.float32), TensorOfShape(1, 1, 1, 1, dtype=torch.float16), + TensorOfShape(1, dtype=torch.float32), **convolution_kwargs) +]) +def aten〇convolution〡dtype(input_rank_dtype: Tuple[int, int], weight_rank_dtype: Tuple[int, int], bias_rank_dtype: Optional[Tuple[int, int]], stride: List[int], padding: List[int], dilation: List[int], transposed: bool, output_padding: List[int], groups: int) -> int: + input_rank, input_dtype = input_rank_dtype + return input_dtype + +convolution_backward_kwargs = { + "bias_sizes" : [1], "stride" : [1, 1], "padding" : [0, 0], "dilation" : [1, 1], "transposed" : False, "output_padding" : [0, 0], "groups" : 1, "output_mask" : [True, True, True]} +@check_dtype_function( + _check_tensors_with_the_same_dtype( + tensor_shapes=[(1, 1, 1, 1), (1, 1, 1, 1), (1, 1, 1, 1)], + **convolution_backward_kwargs) + + # dtype of first three tensors must be float + [Invocation(TensorOfShape(1, 1, 1, 1, dtype=torch.int32), TensorOfShape(1, 1, 1, 1, dtype=torch.float32), + TensorOfShape(1, 1, 1, 1, dtype=torch.float32), **convolution_backward_kwargs), + # dtype of first three tensors must be float + Invocation(TensorOfShape(1, 1, 1, 1, dtype=torch.float32), TensorOfShape(1, 1, 1, 1, dtype=torch.int32), + TensorOfShape(1, 1, 1, 1, dtype=torch.float32), **convolution_backward_kwargs), + # dtype of first three tensors must be float + Invocation(TensorOfShape(1, 1, 1, 1, dtype=torch.float32), TensorOfShape(1, 1, 1, 1, dtype=torch.float32), + TensorOfShape(1, 1, 1, 1, dtype=torch.int32), **convolution_backward_kwargs), + # dtype of first three tensors must be float + Invocation(TensorOfShape(1, 1, 1, 1, dtype=torch.float32), TensorOfShape(1, 1, 1, 1, dtype=torch.float32), + TensorOfShape(1, 1, 1, 1, dtype=torch.float32), **convolution_backward_kwargs), + # grad_output, input, and weight must have same dtype + Invocation(TensorOfShape(1, 1, 1, 1, dtype=torch.float64), TensorOfShape(1, 1, 1, 1, dtype=torch.float32), + TensorOfShape(1, 1, 1, 1, dtype=torch.float32), **convolution_backward_kwargs), + # grad_output, input, and weight must have same dtype + Invocation(TensorOfShape(1, 1, 1, 1, dtype=torch.float32), TensorOfShape(1, 1, 1, 1, dtype=torch.float64), + TensorOfShape(1, 1, 1, 1, dtype=torch.float32), **convolution_backward_kwargs), + # grad_output, input, and weight must have same dtype 
+ Invocation(TensorOfShape(1, 1, 1, 1, dtype=torch.float32), TensorOfShape(1, 1, 1, 1, dtype=torch.float32), + TensorOfShape(1, 1, 1, 1, dtype=torch.float64), **convolution_backward_kwargs), +]) +def aten〇convolution_backward〡dtype(grad_output_rank_dtype: Tuple[int, int], input_rank_dtype: Tuple[int, int], weight_rank_dtype: Tuple[int, int], bias_sizes: Optional[List[int]], stride: List[int], padding: List[int], dilation: List[int], transposed: bool, output_padding: List[int], groups: int, output_mask: List[bool]) -> Tuple[int, int, int]: + grad_output_rank, grad_output_dtype = grad_output_rank_dtype + input_rank, input_dtype = input_rank_dtype + weight_rank, weight_dtype = weight_rank_dtype + return grad_output_dtype, grad_output_dtype, grad_output_dtype + +# TODO: This should be fixed by switching to FakeTensor instead of Meta tensor +@check_dtype_function(_check_tensors_with_the_same_dtype( + num_of_tensors=2, + tensor_device="cpu", + error_types={torch.bool, torch.bfloat16, torch.float16, torch.float32, torch.float64, + torch.complex64, torch.complex128}) + + _check_tensors_with_the_same_dtype( + num_of_tensors=1, + tensor_device="cpu", + error_types={torch.bool, torch.bfloat16, torch.float16, torch.float32, torch.float64, + torch.complex64, torch.complex128})) +def aten〇bincount〡dtype(self_rank_dtype: Tuple[int, int], weights_rank_dtype: Optional[Tuple[int, int]] = None, minlength: int = 0) -> int: + self_rank, self_dtype = self_rank_dtype + assert is_integer_dtype(self_dtype) and self_dtype != torch.bool + if weights_rank_dtype is None: + return torch.int64 + return torch.float64 + +@check_dtype_function( + _check_tensors_with_the_same_dtype(tensor_shapes=[(1, 1), (1, 1), (1, 1)]) + + # Different width + [Invocation(TensorOfShape(3, 3, dtype=torch.float32), + TensorOfShape(3, 4, dtype=torch.float64), + TensorOfShape(4, 3, dtype=torch.float32)), + # Different type + Invocation(TensorOfShape(3, 3, dtype=torch.float32), + TensorOfShape(3, 4, dtype=torch.float32), + TensorOfShape(4, 3, dtype=torch.int32)), + Invocation(TensorOfShape(3, 3, dtype=torch.int32), + TensorOfShape(3, 4, dtype=torch.float32), + TensorOfShape(4, 3, dtype=torch.float32))]) +def aten〇addmm〡dtype(self_rank_dtype: Tuple[int, int], mat1_rank_dtype: Tuple[int, int], mat2_rank_dtype: Tuple[int, int], beta: Union[int, float] = 1, alpha: Union[int, float] = 1) -> int: + self_rank, self_dtype = self_rank_dtype + mat1_rank, mat1_dtype = mat1_rank_dtype + mat2_rank, mat2_dtype = mat2_rank_dtype + + ranks: List[Optional[int]] = [self_rank, mat1_rank, mat2_rank] + dtypes = [self_dtype, mat1_dtype, mat2_dtype] + return promote_dtypes(ranks, dtypes) + +@check_dtype_function( + _check_tensors_with_the_same_dtype(tensor_shapes=[(1, 1), (1, 1), (1, 1)]) + + # Different width + [Invocation(TensorOfShape(4, 3, dtype=torch.float32), + TensorOfShape(4, 3, dtype=torch.float64), + TensorOfShape(4, 3, dtype=torch.float32)), + # Different type + Invocation(TensorOfShape(4, 3, dtype=torch.float32), + TensorOfShape(4, 3, dtype=torch.float32), + TensorOfShape(4, 3, dtype=torch.int32)), + Invocation(TensorOfShape(4, 3, dtype=torch.int32), + TensorOfShape(4, 3, dtype=torch.float32), + TensorOfShape(4, 3, dtype=torch.float32))]) +def aten〇lerp〇Tensor〡dtype(self_rank_dtype: Tuple[int, int], end_rank_dtype: Tuple[int, int], weight_rank_dtype: Tuple[int, int]) -> int: + self_rank, self_dtype = self_rank_dtype + end_rank, end_dtype = end_rank_dtype + weight_rank, weight_dtype = weight_rank_dtype + + ranks: List[Optional[int]] = [self_rank, end_rank, 
weight_rank] + dtypes = [self_dtype, end_dtype, weight_dtype] + return promote_dtypes(ranks, dtypes) + +@check_dtype_function( + _check_tensors_with_the_same_dtype(tensor_shapes=[(1, 1), (1, 1), (1, 1)], error_types={torch.bool}) + + # Different width + [Invocation(TensorOfShape(3, 3, dtype=torch.float32), + TensorOfShape(3, 3, dtype=torch.float64), + TensorOfShape(3, 3, dtype=torch.float32)), + # Different type + Invocation(TensorOfShape(3, 3, dtype=torch.float32), + TensorOfShape(3, 3, dtype=torch.float32), + TensorOfShape(3, 3, dtype=torch.int32)), + Invocation(TensorOfShape(3, 3, dtype=torch.int32), + TensorOfShape(3, 3, dtype=torch.float32), + TensorOfShape(3, 3, dtype=torch.float32))]) +def aten〇addcmul〡dtype(self_rank_dtype: Tuple[int, int], tensor1_rank_dtype: Tuple[int, int], tensor2_rank_dtype: Tuple[int, int], value: Union[int, float] = 1) -> int: + self_rank, self_dtype = self_rank_dtype + tensor1_rank, tensor1_dtype = tensor1_rank_dtype + tensor2_rank, tensor2_dtype = tensor2_rank_dtype + + assert self_dtype != torch.bool + assert tensor1_dtype != torch.bool + assert tensor2_dtype != torch.bool + + ranks: List[Optional[int]] = [self_rank, tensor1_rank, tensor2_rank] + dtypes = [self_dtype, tensor1_dtype, tensor2_dtype] + return promote_dtypes(ranks, dtypes) + +@check_dtype_function( + _check_tensors_with_the_same_dtype(tensor_shapes=[(1, 1), (1, 1), (1, 1)]) + + # Different width + [Invocation(TensorOfShape(3, 3, dtype=torch.float32), + TensorOfShape(3, 3, dtype=torch.float64), + TensorOfShape(3, 3, dtype=torch.float32)), + # Different type + Invocation(TensorOfShape(3, 3, dtype=torch.float32), + TensorOfShape(3, 3, dtype=torch.float32), + TensorOfShape(3, 3, dtype=torch.int32)), + Invocation(TensorOfShape(3, 3, dtype=torch.int32), + TensorOfShape(3, 3, dtype=torch.float32), + TensorOfShape(3, 3, dtype=torch.float32))]) +def aten〇addcdiv〡dtype(self_rank_dtype: Tuple[int, int], tensor1_rank_dtype: Tuple[int, int], tensor2_rank_dtype: Tuple[int, int], value: Union[int, float] = 1) -> int: + self_rank, self_dtype = self_rank_dtype + tensor1_rank, tensor1_dtype = tensor1_rank_dtype + tensor2_rank, tensor2_dtype = tensor2_rank_dtype + + ranks: List[Optional[int]] = [self_rank, tensor1_rank, tensor2_rank] + dtypes = [self_dtype, tensor1_dtype, tensor2_dtype] + result = promote_dtypes(ranks, dtypes) + if is_integer_dtype(result): + return torch.float32 + return result + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, other=1) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, other=1.0)) +def aten〇add〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], other: Union[int, float], alpha: Union[int, float] = 1) -> int: + self_rank, self_dtype = self_rank_dtype + ranks: List[Optional[int]] = [self_rank, None] + dtypes = [self_dtype, get_dtype_of_scalar(other)] + return promote_dtypes(ranks, dtypes) + +@check_dtype_function( + _check_tensors_with_the_same_dtype(num_of_tensors=1, other=1) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, other=1.0)) +def aten〇sub〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], other: Union[int, float], alpha: Union[int, float] = 1) -> int: + self_rank, self_dtype = self_rank_dtype + ranks: List[Optional[int]] = [self_rank, None] + dtypes = [self_dtype, get_dtype_of_scalar(other)] + return promote_dtypes(ranks, dtypes) + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, other=1) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, other=1.0)) +def aten〇mul〇Scalar〡dtype(self_rank_dtype: 
Tuple[int, int], other: Union[int, float]) -> int: + self_rank, self_dtype = self_rank_dtype + ranks: List[Optional[int]] = [self_rank, None] + dtypes = [self_dtype, get_dtype_of_scalar(other)] + return promote_dtypes(ranks, dtypes) + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, other=1) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, other=1.0)) +def aten〇div〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], other: Union[int, float]) -> int: + self_rank, self_dtype = self_rank_dtype + ranks: List[Optional[int]] = [self_rank, None] + dtypes = [self_dtype, get_dtype_of_scalar(other)] + promoted_dtype = promote_dtypes(ranks, dtypes) + if is_integer_dtype(promoted_dtype): + return torch.float32 + return promoted_dtype + +@check_dtype_function( + _check_tensors_with_the_same_dtype(num_of_tensors=1, other=1) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, other=1.0)) +def aten〇fmod〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], other: Union[int, float]) -> int: + self_rank, self_dtype = self_rank_dtype + ranks: List[Optional[int]] = [self_rank, None] + dtypes = [self_dtype, get_dtype_of_scalar(other)] + return promote_dtypes(ranks, dtypes) + +@check_dtype_function( + _check_tensors_with_the_same_dtype(num_of_tensors=1, error_types={torch.complex64, torch.complex128}, other=1) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, error_types={torch.complex64, torch.complex128}, other=1.0)) +def aten〇floor_divide〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], other: Union[int, float]) -> int: + self_rank, self_dtype = self_rank_dtype + assert not is_complex_dtype(self_dtype) + ranks: List[Optional[int]] = [self_rank, None] + dtypes = [self_dtype, get_dtype_of_scalar(other)] + return promote_dtypes(ranks, dtypes) + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, exponent=1) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, exponent=1.0)) +def aten〇pow〇Tensor_Scalar〡dtype(self_rank_dtype: Tuple[int, int], exponent: Union[int, float]) -> int: + self_rank, self_dtype = self_rank_dtype + ranks: List[Optional[int]] = [self_rank, None] + dtypes = [self_dtype, get_dtype_of_scalar(exponent)] + return promote_dtypes(ranks, dtypes) + +@check_dtype_function( + _check_tensors_with_the_same_dtype(num_of_tensors=1, error_types={torch.bool}, negative_slope=1) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, error_types={torch.bool, torch.int8, torch.uint8, torch.int16, torch.int32, torch.int64}, negative_slope=1.0)) +def aten〇leaky_relu〡dtype(self_rank_dtype: Tuple[int, int], negative_slope: Union[int, float] = 0.01) -> int: + self_rank, self_dtype = self_rank_dtype + assert self_dtype != torch.bool + ranks: List[Optional[int]] = [self_rank, None] + negative_slope_dtype = get_dtype_of_scalar(negative_slope) + if is_float_dtype(negative_slope_dtype): + assert not is_integer_dtype(self_dtype) + dtypes = [self_dtype, negative_slope_dtype] + return promote_dtypes(ranks, dtypes) + +@check_dtype_function( + _check_tensors_with_the_same_dtype(num_of_tensors=1, other=1) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, other=1.0)) +def aten〇remainder〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], other: Union[int, float]) -> int: + self_rank, self_dtype = self_rank_dtype + ranks: List[Optional[int]] = [self_rank, None] + dtypes = [self_dtype, get_dtype_of_scalar(other)] + return promote_dtypes(ranks, dtypes) + +# TODO: This should be fixed by switching to FakeTensor instead of Meta tensor +@check_dtype_function( + 
_check_tensors_with_the_same_dtype(tensor_shapes=[(1, 1, 1), (1, 1, 1), (1, 1, 1)], tensor_device="cpu", error_types={torch.bool, torch.float16}) + + [ErrorInvocation(TensorOfShape( + 1, 1, 1, dtype=torch.float64, device="cpu"), TensorOfShape(1, 1, 1, dtype=torch.int16, device="cpu"), TensorOfShape(1, 1, 1, dtype=torch.int32, device="cpu")), + ErrorInvocation( + TensorOfShape(1, 1, 1, dtype=torch.float64, device="cpu"), TensorOfShape(1, 1, 1, dtype=torch.int64, device="cpu"), TensorOfShape(1, 1, 1, dtype=torch.float16, device="cpu")), + ErrorInvocation( + TensorOfShape(1, 1, 1, dtype=torch.float64, device="cpu"), TensorOfShape(1, 1, 1, dtype=torch.float16, device="cpu"), TensorOfShape(1, 1, 1, dtype=torch.int64, device="cpu")), + ErrorInvocation( + TensorOfShape(1, 1, 1, dtype=torch.float64, device="cpu"), TensorOfShape(1, 1, 1, dtype=torch.bfloat16, device="cpu"), TensorOfShape(1, 1, 1, dtype=torch.float16, device="cpu"))]) +def aten〇baddbmm〡dtype(self_rank_dtype: Tuple[int, int], batch1_rank_dtype: Tuple[int, int], batch2_rank_dtype: Tuple[int, int], beta: Union[int, float] = 1, alpha: Union[int, float] = 1) -> int: + batch1_rank, batch1_dtype = batch1_rank_dtype + batch2_rank, batch2_dtype = batch2_rank_dtype + assert batch1_dtype not in [torch.bool, torch.float16] + assert batch2_dtype not in [torch.bool, torch.float16] + assert batch1_dtype == batch2_dtype + ranks: List[Optional[int]] = [batch1_rank, batch2_rank] + dtypes = [batch1_dtype, batch2_dtype] + return promote_dtypes(ranks, dtypes) + +@check_dtype_function([ + Invocation(NonZeroDTensorWithDtype(torch.bool), NonZeroDTensorWithDtype(torch.int16), NonZeroDTensorWithDtype(torch.int32)), + Invocation(NonZeroDTensorWithDtype(torch.bool), NonZeroDTensorWithDtype(torch.int64), NonZeroDTensorWithDtype(torch.float16)), + Invocation(NonZeroDTensorWithDtype(torch.bool), NonZeroDTensorWithDtype(torch.float16), NonZeroDTensorWithDtype(torch.int64)), + Invocation(NonZeroDTensorWithDtype(torch.bool), NonZeroDTensorWithDtype(torch.bfloat16), NonZeroDTensorWithDtype(torch.float16))]) +def aten〇where〇self〡dtype(condition_rank_dtype: Tuple[int, int], self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int]) -> int: + self_rank, self_dtype = self_rank_dtype + other_rank, other_dtype = other_rank_dtype + ranks: List[Optional[int]] = [self_rank, other_rank] + dtypes = [self_dtype, other_dtype] + return promote_dtypes(ranks, dtypes) + +@check_dtype_function([Invocation(NonZeroDTensorWithDtype(torch.bool), 0, 0), + Invocation(NonZeroDTensorWithDtype(torch.bool), 0, 0.0), + Invocation(NonZeroDTensorWithDtype(torch.bool), 0.0, 0), + Invocation(NonZeroDTensorWithDtype(torch.bool), 0.0, 0.0)]) +def aten〇where〇Scalar〡dtype(condition_rank_dtype: Tuple[int, int], self: Union[int, float], other: Union[int, float]) -> int: + if is_integer_dtype(get_dtype_of_scalar(self)) and is_integer_dtype(get_dtype_of_scalar(other)): + return torch.int64 + return torch.float32 + +@check_dtype_function([Invocation(NonZeroDTensorWithDtype(torch.bool), NonZeroDTensorWithDtype(torch.int16), 0), + Invocation(NonZeroDTensorWithDtype(torch.bool), NonZeroDTensorWithDtype(torch.int64), 0.0), + Invocation(NonZeroDTensorWithDtype(torch.bool), NonZeroDTensorWithDtype(torch.float16), 0), + Invocation(NonZeroDTensorWithDtype(torch.bool), NonZeroDTensorWithDtype(torch.float64), 0.0)]) +def aten〇where〇ScalarOther〡dtype(condition_rank_dtype: Tuple[int, int], self_rank_dtype: Tuple[int, int], other: Union[int, float]) -> int: + self_rank, self_dtype = self_rank_dtype + ranks: 
List[Optional[int]] = [self_rank, None] + dtypes = [self_dtype, get_dtype_of_scalar(other)] + return promote_dtypes(ranks, dtypes) + +@check_dtype_function([Invocation(NonZeroDTensorWithDtype(torch.bool), 0, NonZeroDTensorWithDtype(torch.int16)), + Invocation(NonZeroDTensorWithDtype(torch.bool), 0.0, NonZeroDTensorWithDtype(torch.int64)), + Invocation(NonZeroDTensorWithDtype(torch.bool), 0, NonZeroDTensorWithDtype(torch.float16)), + Invocation(NonZeroDTensorWithDtype(torch.bool), 0.0, NonZeroDTensorWithDtype(torch.float64))]) +def aten〇where〇ScalarSelf〡dtype(condition_rank_dtype: Tuple[int, int], self: Union[int, float], other_rank_dtype: Tuple[int, int]) -> int: + other_rank, other_dtype = other_rank_dtype + ranks: List[Optional[int]] = [None, other_rank] + dtypes = [get_dtype_of_scalar(self), other_dtype] + return promote_dtypes(ranks, dtypes) + +@check_dtype_function( + [Invocation(TensorOfShape(2, 3, dtype=torch.float32), TensorOfShape(2, dtype=torch.int64), + TensorOfShape(3, dtype=torch.float32), reduction=0, ignore_index=0), + ErrorInvocation(TensorOfShape(2, 3, dtype=torch.float32), TensorOfShape(2, dtype=torch.int32), # target must be int64 + TensorOfShape(3, dtype=torch.float32), reduction=0, ignore_index=0), + ErrorInvocation(TensorOfShape(2, 3, dtype=torch.float32), TensorOfShape(2, dtype=torch.float64), # target must be int64 + TensorOfShape(3, dtype=torch.float32), reduction=0, ignore_index=0), + Invocation(TensorOfShape(2, 3, dtype=torch.float64), TensorOfShape(2, dtype=torch.int64), # self and weight must have same dtype + TensorOfShape(3, dtype=torch.float32), reduction=0, ignore_index=0), + Invocation(TensorOfShape(2, 3, dtype=torch.int32), TensorOfShape(2, dtype=torch.int64), # self and weight must be float + TensorOfShape(3, dtype=torch.int32), reduction=0, ignore_index=0), + Invocation(TensorOfShape(2, 3, dtype=torch.complex64), TensorOfShape(2, dtype=torch.int64), # self and weight must be float + TensorOfShape(3, dtype=torch.complex64), reduction=0, ignore_index=0)]) +def aten〇nll_loss_forward〡dtype(self_rank_dtype: Tuple[int, int], target_rank_dtype: Tuple[int, int], weight_rank_dtype: Optional[Tuple[int, int]], reduction: int, ignore_index: int) -> Tuple[int, int]: + self_rank, self_dtype = self_rank_dtype + target_rank, target_dtype = target_rank_dtype + assert target_dtype == torch.int64 + return self_dtype, self_dtype + +@check_dtype_function( + [Invocation(TensorOfShape(2, 3, dtype=torch.float32), [3], TensorOfShape(3, dtype=torch.float32), + TensorOfShape(3, dtype=torch.float32), eps=0.0), + Invocation(TensorOfShape(2, 3, dtype=torch.float64), [3], TensorOfShape(3, dtype=torch.float32), + TensorOfShape(3, dtype=torch.float32), eps=0.0), + Invocation(TensorOfShape(2, 3, dtype=torch.float32), [3], TensorOfShape(3, dtype=torch.float64), + TensorOfShape(3, dtype=torch.float32), eps=0.0), + Invocation(TensorOfShape(2, 3, dtype=torch.float32), [3], TensorOfShape(3, dtype=torch.float32), + TensorOfShape(3, dtype=torch.float64), eps=0.0), + # Input must be float or complex + ErrorInvocation(TensorOfShape(2, 3, dtype=torch.int32), [3], TensorOfShape(3, dtype=torch.int32), + TensorOfShape(3, dtype=torch.int32), eps=0.0), + Invocation(TensorOfShape(2, 3, dtype=torch.complex64), [3], TensorOfShape(3, dtype=torch.complex64), + TensorOfShape(3, dtype=torch.complex64), eps=0.0), + Invocation(TensorOfShape(2, 3, dtype=torch.complex128), [3], TensorOfShape(3, dtype=torch.complex64), + TensorOfShape(3, dtype=torch.complex64), eps=0.0), + ]) +def 
aten〇native_layer_norm〡dtype(input_rank_dtype: Tuple[int, int], normalized_shape: List[int], weight_rank_dtype: Optional[Tuple[int, int]], bias_rank_dtype: Optional[Tuple[int, int]], eps: float) -> Tuple[int, int, int]: + input_rank, input_dtype = input_rank_dtype + assert not is_integer_dtype(input_dtype) + result_dtype = input_dtype + if input_dtype == torch.complex64: + result_dtype = torch.float32 + if input_dtype == torch.complex128: + result_dtype = torch.float64 + return input_dtype, input_dtype, result_dtype + +@check_dtype_function( + [Invocation(TensorOfShape(3, 3, dtype=torch.float32), TensorOfShape(3, dtype=torch.float32), + TensorOfShape(3, dtype=torch.float32), TensorOfShape(3, dtype=torch.float32), + TensorOfShape(3, dtype=torch.float32), training=False, momentum=0.0, eps=0.0), + # Tensors with different dtype + Invocation(TensorOfShape(3, 3, dtype=torch.float64), TensorOfShape(3, dtype=torch.float32), + TensorOfShape(3, dtype=torch.float32), TensorOfShape(3, dtype=torch.float32), + TensorOfShape(3, dtype=torch.float32), training=False, momentum=0.0, eps=0.0), + Invocation(TensorOfShape(3, 3, dtype=torch.float32), TensorOfShape(3, dtype=torch.float64), + TensorOfShape(3, dtype=torch.float32), TensorOfShape(3, dtype=torch.float32), + TensorOfShape(3, dtype=torch.float32), training=False, momentum=0.0, eps=0.0), + Invocation(TensorOfShape(3, 3, dtype=torch.float32), TensorOfShape(3, dtype=torch.float32), + TensorOfShape(3, dtype=torch.float64), TensorOfShape(3, dtype=torch.float32), + TensorOfShape(3, dtype=torch.float32), training=False, momentum=0.0, eps=0.0), + Invocation(TensorOfShape(3, 3, dtype=torch.float32), TensorOfShape(3, dtype=torch.float32), + TensorOfShape(3, dtype=torch.float32), TensorOfShape(3, dtype=torch.float64), + TensorOfShape(3, dtype=torch.float32), training=False, momentum=0.0, eps=0.0), + Invocation(TensorOfShape(3, 3, dtype=torch.float32), TensorOfShape(3, dtype=torch.float32), + TensorOfShape(3, dtype=torch.float32), TensorOfShape(3, dtype=torch.float32), + TensorOfShape(3, dtype=torch.float64), training=False, momentum=0.0, eps=0.0), + # Non-float tensors + Invocation(TensorOfShape(3, 3, dtype=torch.int32), TensorOfShape(3, dtype=torch.int32), + TensorOfShape(3, dtype=torch.int32), TensorOfShape(3, dtype=torch.int32), + TensorOfShape(3, dtype=torch.int32), training=False, momentum=0.0, eps=0.0), + Invocation(TensorOfShape(3, 3, dtype=torch.complex64), TensorOfShape(3, dtype=torch.complex64), + TensorOfShape(3, dtype=torch.complex64), TensorOfShape(3, dtype=torch.complex64), + TensorOfShape(3, dtype=torch.complex64), training=False, momentum=0.0, eps=0.0), + ]) +def aten〇native_batch_norm〡dtype(input_rank_dtype: Tuple[int, int], weight_rank_dtype: Optional[Tuple[int, int]], bias_rank_dtype: Optional[Tuple[int, int]], running_mean_rank_dtype: Optional[Tuple[int, int]], running_var_rank_dtype: Optional[Tuple[int, int]], training: bool, momentum: float, eps: float) -> Tuple[int, int, int]: + input_rank, input_dtype = input_rank_dtype + result_dtype = input_dtype + if is_integer_dtype(input_dtype): + result_dtype = torch.float32 + return input_dtype, input_dtype, result_dtype + +@check_dtype_function([Invocation(end=0, dtype=None), # No floats + Invocation(end=0.0, dtype=None), # One float + ErrorInvocation(end=0, dtype=torch.complex64), # Dtype specified + Invocation(end=0, dtype=torch.float16), # Dtype specified + Invocation(end=0, dtype=torch.int16)]) # Dtype specified +def aten〇arange〡dtype(end: Union[int, float], dtype: Optional[int] = None, 
layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None) -> int: + if dtype is not None: + assert not is_complex_dtype(dtype) + return dtype + if is_float_dtype(get_dtype_of_scalar(end)): + return torch.float32 + return torch.int64 + +@check_dtype_function([Invocation(start=0, end=10, dtype=None), # No floats + Invocation(start=0.0, end=10, dtype=None), # One float + Invocation(start=0, end=10.0, dtype=None), # One float + ErrorInvocation(start=0, end=10, dtype=torch.complex64), # Dtype specified + Invocation(start=0, end=10, dtype=torch.float16), # Dtype specified + Invocation(start=0, end=10, dtype=torch.int16)]) # Dtype specified +def aten〇arange〇start〡dtype(start: Union[int, float], end: Union[int, float], dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None) -> int: + if dtype is not None: + assert not is_complex_dtype(dtype) + return dtype + if is_float_dtype(get_dtype_of_scalar(start)) or \ + is_float_dtype(get_dtype_of_scalar(end)): + return torch.float32 + return torch.int64 + +@check_dtype_function([Invocation(start=0, end=10, step=1, dtype=None), # No floats + Invocation(start=0.0, end=10, step=1, dtype=None), # One float + Invocation(start=0, end=10.0, step=1, dtype=None), # One float + Invocation(start=0, end=10, step=1.0, dtype=None), # One float + ErrorInvocation(start=0, end=10, step=1, dtype=torch.complex64), # Dtype specified + Invocation(start=0, end=10, step=1, dtype=torch.float16), # Dtype specified + Invocation(start=0, end=10, step=1, dtype=torch.int16)]) # Dtype specified +def aten〇arange〇start_step〡dtype(start: Union[int, float], end: Union[int, float], step: Union[int, float] = 1, dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None) -> int: + if dtype is not None: + assert not is_complex_dtype(dtype) + return dtype + if is_float_dtype(get_dtype_of_scalar(start)) or \ + is_float_dtype(get_dtype_of_scalar(end)) or \ + is_float_dtype(get_dtype_of_scalar(step)): + return torch.float32 + return torch.int64 + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, dtype=torch.float32) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, dtype=torch.int32) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, dtype=torch.complex64)) +def aten〇sum〡dtype(self_rank_dtype: Tuple[int, int], dtype: Optional[int] = None) -> int: + if dtype is not None: + return dtype + self_rank, self_dtype = self_rank_dtype + if is_integer_dtype(self_dtype): + return torch.int64 + return self_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, dim=None) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, dim=None, dtype=torch.float32) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, dim=None, dtype=torch.int32) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, dim=None, dtype=torch.complex64)) +def aten〇sum〇dim_IntList〡dtype(self_rank_dtype: Tuple[int, int], dim: Optional[List[int]], keepdim: bool = False, dtype: Optional[int] = None) -> int: + return aten〇sum〡dtype(self_rank_dtype, dtype) + +@check_dtype_function( + _check_tensors_with_the_same_dtype( + num_of_tensors=1, + error_types={torch.bool, torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64}, dim=None) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, dim=None, dtype=torch.float32) + + 
_check_tensors_with_the_same_dtype(num_of_tensors=1, dim=None, dtype=torch.complex64) + + [ErrorInvocation(NonZeroDTensorWithDtype(torch.float32), dim=None, dtype=torch.int32)]) +def aten〇mean〇dim〡dtype(self_rank_dtype: Tuple[int, int], dim: Optional[List[int]], keepdim: bool = False, dtype: Optional[int] = None) -> int: + self_rank, self_dtype = self_rank_dtype + result = aten〇sum〡dtype(self_rank_dtype, dtype) + assert not is_integer_dtype(result) + return result + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) +def aten〇argmax〡dtype(self_rank_dtype: Tuple[int, int], dim: Optional[int] = None, keepdim: bool = False) -> int: + self_rank, self_dtype = self_rank_dtype + return torch.int64 + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, dim=0)) +def aten〇any〇dim〡dtype(self_rank_dtype: Tuple[int, int], dim: int, keepdim: bool = False) -> int: + self_rank, self_dtype = self_rank_dtype + if self_dtype == torch.uint8: + return self_dtype + return torch.bool + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) +def aten〇max〡dtype(self_rank_dtype: Tuple[int, int]) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) +def aten〇amax〡dtype(self_rank_dtype: Tuple[int, int], dim: List[int] = (), keepdim: bool = False) -> int: + return aten〇max〡dtype(self_rank_dtype) + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, dim=0)) +def aten〇max〇dim〡dtype(self_rank_dtype: Tuple[int, int], dim: int, keepdim: bool = False) -> Tuple[int, int]: + return aten〇max〡dtype(self_rank_dtype), torch.int64 + +@check_dtype_function( + _check_tensors_with_the_same_dtype( + num_of_tensors=1, + error_types={torch.bool, torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64}) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, dtype=torch.float32) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, dtype=torch.complex64) + + [ErrorInvocation(NonZeroDTensorWithDtype(torch.float32), dtype=torch.int32)]) +def aten〇mean〡dtype(self_rank_dtype: Tuple[int, int], dtype: Optional[int] = None) -> int: + return aten〇mean〇dim〡dtype(self_rank_dtype, dim=None, dtype=dtype) + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) +def aten〇std〡dtype(self_rank_dtype: Tuple[int, int], unbiased: bool = True) -> int: + self_rank, self_dtype = self_rank_dtype + if self_dtype == torch.complex64: + return torch.float32 + if self_dtype == torch.complex128: + return torch.float64 + return self_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, dim=None)) +def aten〇std〇dim〡dtype(self_rank_dtype: Tuple[int, int], dim: Optional[List[int]], unbiased: bool = True, keepdim: bool = False) -> int: + return aten〇std〡dtype(self_rank_dtype) + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) +def aten〇std〇correction〡dtype(self_rank_dtype: Tuple[int, int], dim: Optional[List[int]] = None, correction: Optional[Union[int, float]] = None, keepdim: bool = False) -> int: + return aten〇std〡dtype(self_rank_dtype) + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) +def aten〇var〡dtype(self_rank_dtype: Tuple[int, int], unbiased: bool = True) -> int: + return aten〇std〡dtype(self_rank_dtype) + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, dim=None)) +def aten〇var〇dim〡dtype(self_rank_dtype: Tuple[int, int], dim: 
Optional[List[int]], unbiased: bool = True, keepdim: bool = False) -> int: + return aten〇std〡dtype(self_rank_dtype) + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) +def aten〇var〇correction〡dtype(self_rank_dtype: Tuple[int, int], dim: Optional[List[int]] = None, correction: Optional[Union[int, float]] = None, keepdim: bool = False) -> int: + return aten〇std〡dtype(self_rank_dtype) + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, dims=[], correction=0.0)) +def prims〇var〡dtype(inp_rank_dtype: Tuple[int, int], dims: Optional[List[int]], correction: float, output_dtype: Optional[int] = None) -> int: + return aten〇std〡dtype(inp_rank_dtype) + +@check_dtype_function( + _check_tensors_with_the_same_dtype( + num_of_tensors=1, + error_types={torch.bool, torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64}) + + _check_tensors_with_the_same_dtype( + num_of_tensors=1, + error_types={torch.bool, torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64, torch.complex64, torch.complex128}, dtype=torch.float64) + + _check_tensors_with_the_same_dtype( + num_of_tensors=1, + error_types={torch.bool, torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64, torch.bfloat16, torch.float16, torch.float32, torch.float64}, dtype=torch.complex128) + + [ErrorInvocation(NonZeroDTensorWithDtype(torch.float32), dtype=torch.int32)]) +def aten〇linalg_vector_norm〡dtype(self_rank_dtype: Tuple[int, int], ord: Union[int, float] = 2, dim: Optional[List[int]] = None, keepdim: bool = False, dtype: Optional[int] = None) -> int: + self_rank, self_dtype = self_rank_dtype + assert not is_integer_dtype(self_dtype) + if dtype is not None: + assert not is_integer_dtype(dtype) + if is_complex_dtype(self_dtype): + assert is_complex_dtype(dtype) + return aten〇std〡dtype((self_rank, dtype)) + assert not is_complex_dtype(dtype) + return dtype + return aten〇std〡dtype(self_rank_dtype) + +@check_dtype_function([Invocation(0.0), + Invocation(0.0, dtype=torch.int32), + Invocation(0.0, dtype=torch.float16), + Invocation(0.0, dtype=torch.complex64)]) +def aten〇tensor〇float〡dtype(t: float, dtype: Optional[int] = None, device: Optional[device] = None, requires_grad: bool = False) -> int: + if dtype is None: + return torch.float32 + return dtype + +@check_dtype_function([Invocation(0), + Invocation(0, dtype=torch.int32), + Invocation(0, dtype=torch.float16), + Invocation(0, dtype=torch.complex64)]) +def aten〇tensor〇int〡dtype(t: int, dtype: Optional[int] = None, device: Optional[device] = None, requires_grad: bool = False) -> int: + if dtype is None: + return torch.int64 + return dtype + +@check_dtype_function([Invocation(True), + Invocation(True, dtype=torch.int32), + Invocation(True, dtype=torch.float16), + Invocation(True, dtype=torch.complex64)]) +def aten〇tensor〇bool〡dtype(t: bool, dtype: Optional[int] = None, device: Optional[device] = None, requires_grad: bool = False) -> int: + if dtype is None: + return torch.bool + return dtype + +@check_dtype_function([Invocation([1]), + Invocation([1], dtype=torch.int32), + Invocation([1], dtype=torch.float16), + Invocation([1], dtype=torch.complex64)]) +def aten〇zeros〡dtype(size: List[int], dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None) -> int: + return torch.float32 if dtype is None else dtype + +@check_dtype_function([Invocation([1]), + Invocation([1], dtype=torch.int32), + Invocation([1], dtype=torch.float16), + Invocation([1], 
dtype=torch.complex64)]) +def aten〇ones〡dtype(size: List[int], dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None) -> int: + return torch.float32 if dtype is None else dtype + +@check_dtype_function([Invocation([1]), + Invocation([1], dtype=torch.int32), + Invocation([1], dtype=torch.float16), + Invocation([1], dtype=torch.complex64)]) +def aten〇empty〇memory_format〡dtype(size: List[int], dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None, memory_format: Optional[int] = None) -> int: + return torch.float32 if dtype is None else dtype + +@check_dtype_function([Invocation([1], 0.0), + Invocation([1], 0), + Invocation([1], 0.0, dtype=torch.int32), + Invocation([1], 0.0, dtype=torch.float16), + Invocation([1], 0.0, dtype=torch.complex64)]) +def aten〇full〡dtype(size: List[int], fill_value: Union[int, float], dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None) -> int: + if dtype is not None: + return dtype + fill_value_dtype = get_dtype_of_scalar(fill_value) + if is_float_dtype(fill_value_dtype): + return torch.float32 + return fill_value_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, dtype=torch.float16) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, dtype=torch.int32) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, dtype=torch.complex64)) +def aten〇zeros_like〡dtype(self_rank_dtype: Tuple[int, int], dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None, memory_format: Optional[int] = None) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype if dtype is None else dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, dtype=torch.float16) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, dtype=torch.int32) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, dtype=torch.complex64)) +def aten〇ones_like〡dtype(self_rank_dtype: Tuple[int, int], dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None, memory_format: Optional[int] = None) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype if dtype is None else dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, dtype=torch.float16) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, dtype=torch.int32) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, dtype=torch.complex64)) +def aten〇empty_like〡dtype(self_rank_dtype: Tuple[int, int], dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None, memory_format: Optional[int] = None) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype if dtype is None else dtype + +@check_dtype_function( + _check_tensors_with_the_same_dtype(num_of_tensors=1, fill_value=0.0) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, fill_value=0) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, fill_value=0.0, dtype=torch.float16) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, fill_value=0.0, dtype=torch.int32) + + 
_check_tensors_with_the_same_dtype(num_of_tensors=1, fill_value=0.0, dtype=torch.complex64)) +def aten〇full_like〡dtype(self_rank_dtype: Tuple[int, int], fill_value: Union[int, float], dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None, memory_format: Optional[int] = None) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype if dtype is None else dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, size=[1]) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, size=[1], dtype=torch.float16) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, size=[1], dtype=torch.int32) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, size=[1], dtype=torch.complex64)) +def aten〇new_zeros〡dtype(self_rank_dtype: Tuple[int, int], size: List[int], dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype if dtype is None else dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, size=[1]) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, size=[1], dtype=torch.float16) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, size=[1], dtype=torch.int32) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, size=[1], dtype=torch.complex64)) +def aten〇new_ones〡dtype(self_rank_dtype: Tuple[int, int], size: List[int], dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype if dtype is None else dtype + +@check_dtype_function( + _check_tensors_with_the_same_dtype(num_of_tensors=1, size=[1]) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, size=[1], dtype=torch.float16) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, size=[1], dtype=torch.int32) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, size=[1], dtype=torch.complex64)) +def aten〇new_empty〡dtype(self_rank_dtype: Tuple[int, int], size: List[int], dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype if dtype is None else dtype + +@check_dtype_function( + _check_tensors_with_the_same_dtype(num_of_tensors=1, size=[1], stride=[1]) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, size=[1], stride=[1], dtype=torch.float16) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, size=[1], stride=[1], dtype=torch.int32) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, size=[1], stride=[1], dtype=torch.complex64)) +def aten〇new_empty_strided〡dtype(self_rank_dtype: Tuple[int, int], size: List[int], stride: List[int], dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype if dtype is None else dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, dtype=torch.float16) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, dtype=torch.int32) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, dtype=torch.complex64)) +def aten〇rand_like〡dtype(self_rank_dtype: Tuple[int, int], dtype: Optional[int] = None, 
layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None, memory_format: Optional[int] = None) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype if dtype is None else dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, error_types=all_integer_dtypes()) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, dtype=torch.float16) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, error_types=all_integer_dtypes() + all_float_dtypes() + all_complex_dtypes(), dtype=torch.int32) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, dtype=torch.complex64)) +def aten〇randn_like〡dtype(self_rank_dtype: Tuple[int, int], dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None, memory_format: Optional[int] = None) -> int: + self_rank, self_dtype = self_rank_dtype + result_dtype = self_dtype if dtype is None else dtype + assert not is_integer_dtype(result_dtype) + return result_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, dtype=torch.float16) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, dtype=torch.int32) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, dtype=torch.complex64)) +def aten〇_to_copy〡dtype(self_rank_dtype: Tuple[int, int], dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None, non_blocking: bool = False, memory_format: Optional[int] = None) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype if dtype is None else dtype + +@check_dtype_function( + _check_tensors_with_the_same_dtype(num_of_tensors=1, dtype=torch.float16) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, dtype=torch.int32) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, dtype=torch.complex64)) +def aten〇to〇dtype〡dtype(self_rank_dtype: Tuple[int, int], dtype: int, non_blocking: bool = False, copy: bool = False, memory_format: Optional[int] = None) -> int: + return dtype + +@check_dtype_function( + _check_tensors_with_the_same_dtype(num_of_tensors=1, dtype=torch.float16) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, dtype=torch.int32) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, dtype=torch.complex64)) +def nvprims〇convert_element_type〡dtype(a_rank_dtype: Tuple[int, int], dtype: int) -> int: + return dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, dtype=torch.float16) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, dtype=torch.int32) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, dtype=torch.complex64)) +def aten〇to〇dtype_layout〡dtype(self_rank_dtype: Tuple[int, int], dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None, non_blocking: bool = False, copy: bool = False, memory_format: Optional[int] = None) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype if dtype is None else dtype + +@check_dtype_function( + _check_tensors_with_the_same_dtype(num_of_tensors=1, device="meta", dtype=torch.float16) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, device="meta", dtype=torch.int32) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, device="meta", dtype=torch.complex64)) +def 
aten〇to〇device〡dtype(self_rank_dtype: Tuple[int, int], device: device, dtype: int, non_blocking: bool = False, copy: bool = False, memory_format: Optional[int] = None) -> int: + return dtype + +@check_dtype_function(_check_two_tensor_op()) +def aten〇to〇other〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int], non_blocking: bool = False, copy: bool = False, memory_format: Optional[int] = None) -> int: + other_rank, other_dtype = other_rank_dtype + return other_dtype + +@check_dtype_function(_check_two_tensor_op()) +def aten〇type_as〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int]) -> int: + other_rank, other_dtype = other_rank_dtype + return other_dtype + +@check_dtype_function([Invocation(low=0, high=10, size=[1]), + Invocation(low=0, high=10, size=[1], dtype=torch.float32), + Invocation(low=0, high=10, size=[1], dtype=torch.int32), + ErrorInvocation(low=0, high=10, size=[1], dtype=torch.complex64)]) +def aten〇randint〇low〡dtype(low: int, high: int, size: List[int], dtype: Optional[int] = 4, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None) -> int: + if dtype is None: + return torch.int64 + assert not is_complex_dtype(dtype) + return dtype + +@check_dtype_function([Invocation(size=[1]), + Invocation(size=[1], dtype=torch.float32), + ErrorInvocation(size=[1], dtype=torch.int32), + Invocation(size=[1], dtype=torch.complex64)]) +def aten〇randn〡dtype(size: List[int], dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None) -> int: + if dtype is None: + return torch.float32 + assert not is_integer_dtype(dtype) + return dtype + +@check_dtype_function([Invocation(size=[1], generator=None), + Invocation(size=[1], generator=None, dtype=torch.float32), + ErrorInvocation(size=[1], generator=None, dtype=torch.int32), + Invocation(size=[1], generator=None, dtype=torch.complex64)]) +def aten〇randn〇generator〡dtype(size: List[int], generator: Any, dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None) -> int: + if dtype is None: + return torch.float32 + assert not is_integer_dtype(dtype) + return dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, error_types=all_integer_dtypes())) +def aten〇var_mean〇correction〡dtype(self_rank_dtype: Tuple[int, int], dim: Optional[List[int]] = None, correction: Optional[Union[int, float]] = None, keepdim: bool = False) -> Tuple[int, int]: + self_rank, self_dtype = self_rank_dtype + assert not is_integer_dtype(self_dtype) + if self_dtype == torch.complex64: + return torch.float32, self_dtype + if self_dtype == torch.complex128: + return torch.float64, self_dtype + return self_dtype, self_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, error_types=all_integer_dtypes())) +def aten〇var_mean〡dtype(self_rank_dtype: Tuple[int, int], unbiased: bool = True) -> Tuple[int, int]: + self_rank, self_dtype = self_rank_dtype + assert not is_integer_dtype(self_dtype) + if self_dtype == torch.complex64: + return torch.float32, self_dtype + if self_dtype == torch.complex128: + return torch.float64, self_dtype + return self_dtype, self_dtype + +@check_dtype_function(_check_two_tensor_op()) +def aten〇atan2〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int]) -> int: + self_rank, self_dtype = self_rank_dtype + other_rank, other_dtype = other_rank_dtype + ranks: 
List[Optional[int]] = [self_rank, other_rank] + dtypes = [self_dtype, other_dtype] + promoted_dtype = promote_dtypes(ranks, dtypes) + if is_integer_dtype(promoted_dtype): + return torch.float32 + return promoted_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) +def aten〇atan〡dtype(self_rank_dtype: Tuple[int, int]) -> int: + self_rank, self_dtype = self_rank_dtype + if is_integer_dtype(self_dtype): + return torch.float32 + return self_dtype + +@check_dtype_function(_check_two_tensor_op()) +def aten〇linear〡dtype(input_rank_dtype: Tuple[int, int], weight_rank_dtype: Tuple[int, int], bias_rank_dtype: Optional[Tuple[int, int]] = None) -> int: + input_rank, input_dtype = input_rank_dtype + weight_rank, weight_dtype = weight_rank_dtype + return input_dtype + +@check_dtype_function( + [Invocation([NonZeroDTensorWithDtype(torch.float32), NonZeroDTensorWithDtype(torch.int32)]), + Invocation([NonZeroDTensorWithDtype(torch.float16), NonZeroDTensorWithDtype(torch.float64)]), + Invocation([NonZeroDTensorWithDtype(torch.float32), NonZeroDTensorWithDtype(torch.int32), + NonZeroDTensorWithDtype(torch.complex64)])]) +def aten〇cat〡dtype(tensors_rank_dtype: List[Tuple[int, int]], dim: int = 0) -> int: + ranks: List[Optional[int]] = [] + dtypes: List[int] = [] + assert len(tensors_rank_dtype) != 0 + for tensor_rank_dtype in tensors_rank_dtype: + tensor_rank, tensor_dtype = tensor_rank_dtype + ranks.append(tensor_rank) + dtypes.append(tensor_dtype) + return promote_dtypes(ranks, dtypes) + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) +def aten〇_shape_as_tensor〡dtype(self_rank_dtype: Tuple[int, int]) -> int: + return torch.int64 + +# Does not work on meta backend +# TODO: This should be fixed by switching to FakeTensor instead of Meta tensor +@check_dtype_function(_check_tensors_with_the_same_dtype(tensor_shapes=[()], tensor_device="cpu", error_types=all_complex_dtypes())) +def aten〇ScalarImplicit〡dtype(a_rank_dtype: Tuple[int, int]) -> int: + a_rank, a_dtype = a_rank_dtype + assert not is_complex_dtype(a_dtype) + if is_float_dtype(a_dtype): + return torch.float64 + if is_integer_dtype(a_dtype) and a_dtype != torch.bool: + return torch.int64 + if a_dtype == torch.bool: + return torch.bool + assert False, "Unexpected dtype!" 
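+
+# A worked illustration of the `promote_dtypes` pattern used throughout
+# this file: tensor operands pass their actual rank, while scalar operands
+# (as in `aten〇where〇ScalarSelf〡dtype` above) pass a rank of `None`,
+# which marks them as wrapped numbers with lower promotion priority. A
+# minimal sketch, assuming `promote_dtypes` mirrors PyTorch's
+# `torch.result_type` promotion rules:
+#
+#     ranks: List[Optional[int]] = [2, None]            # rank-2 tensor, scalar
+#     dtypes = [torch.int32, get_dtype_of_scalar(2.0)]  # int32, float64
+#     promote_dtypes(ranks, dtypes)                     # -> torch.float32
+#
+# The wrapped float scalar promotes the integer tensor only up to the
+# default float dtype (float32), not to the scalar's own float64.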
+ +@check_dtype_function([Invocation(0), Invocation(0.0)]) +def prim〇NumToTensor〇Scalar〡dtype(a: Union[int, float]) -> int: + return get_dtype_of_scalar(a) + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, dim=0) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, dim=0, dtype=torch.int32) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, dim=0, dtype=torch.float16) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, dim=0, dtype=torch.complex64)) +def aten〇softmax〇int〡dtype(self_rank_dtype: Tuple[int, int], dim: int, dtype: Optional[int] = None) -> int: + self_rank, self_dtype = self_rank_dtype + if dtype is None: + return self_dtype + return dtype + +@check_dtype_function( + _check_tensors_with_the_same_dtype(num_of_tensors=1, dim=0, half_to_float=False) + + _check_tensors_with_the_same_dtype( + num_of_tensors=1, + error_types=(all_integer_dtypes() + all_complex_dtypes() + [torch.bfloat16, torch.float32, torch.float64]), + dim=0, half_to_float=True)) +def aten〇_softmax〡dtype(self_rank_dtype: Tuple[int, int], dim: int, half_to_float: bool) -> int: + self_rank, self_dtype = self_rank_dtype + if half_to_float: + assert self_dtype == torch.float16 + return torch.float32 + return self_dtype + +@check_dtype_function( + _check_tensors_with_the_same_dtype(num_of_tensors=1, dim=0, half_to_float=False) + + _check_tensors_with_the_same_dtype( + num_of_tensors=1, + error_types=(all_integer_dtypes() + all_complex_dtypes() + [torch.bfloat16, torch.float32, torch.float64]), + dim=0, half_to_float=True)) +def aten〇_log_softmax〡dtype(self_rank_dtype: Tuple[int, int], dim: int, half_to_float: bool) -> int: + self_rank, self_dtype = self_rank_dtype + if half_to_float: + assert self_dtype == torch.float16 + return torch.float32 + return self_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, dim=0) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, dim=0, dtype=torch.int32) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, dim=0, dtype=torch.float16) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, dim=0, dtype=torch.complex64)) +def aten〇log_softmax〇int〡dtype(self_rank_dtype: Tuple[int, int], dim: int, dtype: Optional[int] = None) -> int: + self_rank, self_dtype = self_rank_dtype + if dtype is None: + return self_dtype + return dtype + +# TODO: to test these functions, we need to be able to specify the tensor contents used in each invocation +def aten〇embedding〡dtype(weight_rank_dtype: Tuple[int, int], indices_rank_dtype: Tuple[int, int], padding_idx: int = -1, scale_grad_by_freq: bool = False, sparse: bool = False) -> int: + weight_rank, weight_dtype = weight_rank_dtype + return weight_dtype + +# TODO: to test these functions, we need to be able to specify the tensor contents used in each invocation +def aten〇_embedding_bag〡dtype(weight_rank_dtype: Tuple[int, int], indices_rank_dtype: Tuple[int, int], offsets_rank_dtype: Tuple[int, int], scale_grad_by_freq: bool = False, mode: int = 0, sparse: bool = False, per_sample_weights_rank_dtype: Optional[Tuple[int, int]] = None, include_last_offset: bool = False, padding_idx: int = -1) -> Tuple[int, int, int, int]: + weight_rank, weight_dtype = weight_rank_dtype + return weight_dtype, torch.int64, torch.int64, torch.int64 + +# TODO: to test these functions, we need to be able to specify the tensor contents used in each invocation +def aten〇embedding_bag〇padding_idx〡dtype(weight_rank_dtype: Tuple[int, int], indices_rank_dtype: Tuple[int, int], offsets_rank_dtype: 
Tuple[int, int], scale_grad_by_freq: bool, mode: int, sparse: bool, per_sample_weights_rank_dtype: Optional[Tuple[int, int]], include_last_offset: bool, padding_idx: Optional[int]) -> Tuple[int, int, int, int]: + weight_rank, weight_dtype = weight_rank_dtype + return weight_dtype, torch.int64, torch.int64, torch.int64 + +@check_dtype_function(_check_two_tensor_op(out_int32=True) + _check_two_tensor_op(out_int32=False)) +def aten〇bucketize〇Tensor〡dtype(self_rank_dtype: Tuple[int, int], boundaries_rank_dtype: Tuple[int, int], out_int32: bool = False, right: bool = False) -> int: + if out_int32: + return torch.int32 + return torch.int64 + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, dimensions=[])) +def prims〇squeeze〡dtype(a_rank_dtype: Tuple[int, int], dimensions: List[int]) -> int: + a_rank, a_dtype = a_rank_dtype + return a_dtype + +# ============================================================================== +# Main +# ============================================================================== + def _maybe_import_op_extensions(args: argparse.Namespace): extension_string = str.strip(args.pytorch_op_extensions) if len(extension_string) > 0: diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/library_generator.py b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/library_generator.py index f87a7019d..3cfc4a24a 100644 --- a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/library_generator.py +++ b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/library_generator.py @@ -14,6 +14,55 @@ from torch_mlir.passmanager import PassManager from .registry import Registry +def all_integer_dtypes() -> List[int]: + return [torch.bool, torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64] + +def is_integer_dtype(dtype: int) -> bool: + return dtype in all_integer_dtypes() + +def all_complex_dtypes() -> List[int]: + return [torch.complex64, torch.complex128] + +def is_complex_dtype(dtype: int) -> bool: + return dtype in all_complex_dtypes() + +def all_float_dtypes() -> List[int]: + return [torch.float16, torch.bfloat16, torch.float32, torch.float64] + +def is_float_dtype(dtype: int) -> bool: + return dtype in all_float_dtypes() + +def get_priority_of_dtype(dtype: int) -> int: + # If a loop is used to iterate over a list of sorted dtypes, TorchScript + produces a loop with INT64_MAX max trip count, which causes problems + during the loop unrolling that takes place when simplifying the dtype + functions. Therefore, here we resort to `if`s. + if dtype == torch.bool: + return 0 + if dtype == torch.uint8: + return 1 + if dtype == torch.int8: + return 2 + if dtype == torch.int16: + return 3 + if dtype == torch.int32: + return 4 + if dtype == torch.int64: + return 5 + if dtype == torch.bfloat16: + return 6 + if dtype == torch.float16: + return 7 + if dtype == torch.float32: + return 8 + if dtype == torch.float64: + return 9 + if dtype == torch.complex64: + return 10 + if dtype == torch.complex128: + return 11 + assert False, "Cannot determine priority of dtype" +
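+# As a hypothetical counter-example (not part of this patch) of the
+# scripted-loop problem described in `get_priority_of_dtype` above, a
+# table-driven lookup such as:
+#
+#     _SORTED_DTYPES = [torch.bool, torch.uint8, ..., torch.complex128]
+#     for priority, d in enumerate(_SORTED_DTYPES):
+#         if d == dtype:
+#             return priority
+#
+# would `jit.script` to a loop with a max trip count of INT64_MAX, which
+# the dtype-function simplification pipeline cannot fully unroll. The
+# chain of `if`s trades verbosity for scripted IR that is loop-free and
+# therefore trivially simplifiable.
+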
+def get_dtype_of_scalar(scalar: Union[int, float]) -> int: + # This is hacky. `NumToTensor` is the only PyTorch op for scalars # that when `jit.script`ed converts a float scalar to a tensor diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/testing_framework.py b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/testing_framework.py index efd270b78..6c2a81757 100644 --- a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/testing_framework.py +++ b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/testing_framework.py @@ -60,33 +60,32 @@ class TensorOfShape: This class also tracks a dtype of the tensor, since some ops require a specific dtype. """ - def __init__(self, *shape: int, dtype: torch.dtype = torch.float32): + def __init__(self, *shape: int, dtype: torch.dtype = torch.float32, + device: Optional[torch.device] = None): self.shape = list(shape) self.dtype = dtype + self.device = "meta" if device is None else device def __repr__(self): args_str = ", ".join(repr(x) for x in self.shape) - if self.dtype is torch.float32: - return f"TensorOfShape({args_str})" - else: - return f"TensorOfShape({args_str}, dtype={self.dtype})" + return f"TensorOfShape({args_str}, dtype={self.dtype}, device={self.device})" def LongTensorOfShape(*args, **kwargs): """Helper for indicating a TensorOfShape with integer type.""" return TensorOfShape(*args, **kwargs, dtype=torch.long) -def NonZeroDTensorWithDtype(dtype): +def NonZeroDTensorWithDtype(dtype, device: Optional[torch.device] = None): """Helper for indicating a non-zero dim tensor with custom type.""" - return TensorOfShape(1, dtype=dtype) + return TensorOfShape(1, dtype=dtype, device=device) -def ZeroDTensorWithDtype(dtype): +def ZeroDTensorWithDtype(dtype, device: Optional[torch.device] = None): """Helper for indicating a zero dim tensor with custom type.""" - return TensorOfShape(dtype=dtype) + return TensorOfShape(dtype=dtype, device=device) def _recursively_transform_tensor_args( o: Any, tensor_transformer: Callable[[TensorOfShape], Any]) -> Any: """Replace `TensorOfShape` with the result of `tensor_transformer`""" - if o is None or isinstance(o, (float, int)): + if o is None or isinstance(o, (float, int, str)): return o if isinstance(o, TensorOfShape): return tensor_transformer(o) @@ -146,7 +145,7 @@ class Invocation: def to_real_op_args(self): """Gets positional arguments appropriate for the real op.""" - tensor_transformer = lambda o: torch.ones(o.shape, dtype=o.dtype) + tensor_transformer = lambda o: torch.ones(o.shape, dtype=o.dtype).to(o.device) return _recursively_transform_tensor_args(self.args, tensor_transformer) def __repr__(self) -> str: @@ -258,6 +257,15 @@ def check_shape_function(invocations: List[Invocation]): return f return decorator +@torch.jit.script +def _convert_dtype_to_int(dtype: torch.dtype) -> int: + """Convert a PyTorch `dtype` into its underlying `int` representation. + + This works because in TorchScript there is no special type for `dtypes`; + they are simply `int`s (e.g., `torch.float32` is the integer `6`). + """ + return dtype + def check_dtype_function(invocations: List[Invocation]): """Decorator that automatically tests a dtype function. @@ -281,7 +289,12 @@ def check_dtype_function(invocations: List[Invocation]): golden_dtype = torch.tensor([]).to(type(golden_result)).dtype else: raise ValueError(f"Unhandled return type {type(golden_result)}") - if result_dtype != golden_dtype: + # Some dtype functions have default `dtype` parameters, which are + represented as `int` values in the registry.
In order to + # support returning the default `int` value, the comparisons of + # the result and golden dtypes are done using their underlying + # `int` representation. + if _convert_dtype_to_int(result_dtype) != _convert_dtype_to_int(golden_dtype): _report(f, invocation, f"Expected result dtype {golden_dtype}, got {result_dtype}") return f return decorator diff --git a/test/Dialect/Torch/refine-types-branch.mlir b/test/Dialect/Torch/refine-types-branch.mlir deleted file mode 100644 index 87ff96576..000000000 --- a/test/Dialect/Torch/refine-types-branch.mlir +++ /dev/null @@ -1,153 +0,0 @@ -// RUN: torch-mlir-opt -torch-refine-types -split-input-file %s | FileCheck %s - -// ----- - -// CHECK-LABEL: func.func @prim.if$branch_merge_type_tensor( -// CHECK-SAME: %[[PRED:.*]]: !torch.bool, -// CHECK-SAME: %[[T1:.*]]: !torch.tensor, -// CHECK-SAME: %[[T2:.*]]: !torch.tensor) -> !torch.bool { -// CHECK: %[[MERGED:.*]] = torch.prim.If %[[PRED]] -> (!torch.optional) { -// CHECK: %[[OPTIONAL:.*]] = torch.derefine %[[T1]] : !torch.tensor to !torch.optional -// CHECK: torch.prim.If.yield %[[OPTIONAL]] : !torch.optional -// CHECK: } else { -// CHECK: %[[OPTIONAL:.*]] = torch.derefine %[[T2]] : !torch.tensor to !torch.optional -// CHECK: torch.prim.If.yield %[[OPTIONAL]] : !torch.optional -// CHECK: } -// CHECK: %[[REFINED:.*]] = torch.prim.unchecked_cast %[[MERGED:.*]] : !torch.optional -> !torch.tensor -// CHECK: %[[NONE:.*]] = torch.constant.none -// CHECK: %[[RET:.*]] = torch.aten.__isnot__ %[[REFINED]], %[[NONE]] : !torch.tensor, !torch.none -> !torch.bool -// CHECK: return %[[RET]] : !torch.bool - -func.func @prim.if$branch_merge_type_tensor(%pred: !torch.bool, %t0: !torch.tensor, %t1: !torch.tensor) -> !torch.bool { - %res = torch.prim.If %pred -> (!torch.optional) { - %optional0 = torch.derefine %t0: !torch.tensor to !torch.optional - torch.prim.If.yield %optional0: !torch.optional - } else { - %optional1 = torch.derefine %t1: !torch.tensor to !torch.optional - torch.prim.If.yield %optional1: !torch.optional - } - %none = torch.constant.none - %cmp = torch.aten.__isnot__ %res, %none : !torch.optional, !torch.none -> !torch.bool - return %cmp : !torch.bool -} - -// ----- - -// CHECK-LABEL: func.func @prim.if$branch_merge_type_optional( -// CHECK-SAME: %[[PRED:.*]]: !torch.bool, -// CHECK-SAME: %[[T:.*]]: !torch.tensor) -> !torch.optional { -// CHECK: %[[MERGED:.*]] = torch.prim.If %[[PRED]] -> (!torch.optional) { -// CHECK: %[[NONE:.*]] = torch.constant.none -// CHECK: %[[OPTIONAL:.*]] = torch.derefine %[[NONE]] : !torch.none to !torch.optional -// CHECK: torch.prim.If.yield %[[OPTIONAL]] : !torch.optional -// CHECK: } else { -// CHECK: %[[OPTIONAL:.*]] = torch.derefine %[[T]] : !torch.tensor to !torch.optional -// CHECK: torch.prim.If.yield %[[OPTIONAL]] : !torch.optional -// CHECK: } -// CHECK: return %[[MERGED:.*]] : !torch.optional - -func.func @prim.if$branch_merge_type_optional(%pred: !torch.bool, %t1: !torch.tensor) -> !torch.optional { - %res = torch.prim.If %pred -> (!torch.optional) { - %none = torch.constant.none - %optional0 = torch.derefine %none: !torch.none to !torch.optional - torch.prim.If.yield %optional0: !torch.optional - } else { - %optional1 = torch.derefine %t1: !torch.tensor to !torch.optional - torch.prim.If.yield %optional1: !torch.optional - } - return %res: !torch.optional -} - -// ----- - -// CHECK-LABEL: func.func @prim.if$refined_type_conflicting( -// CHECK-SAME: %[[NONE:.*]]: !torch.none) -> !torch.tensor { -// CHECK: %[[OPTIONAL:.*]] = torch.derefine %[[NONE]] : 
!torch.none to !torch.optional -// CHECK: %[[NOT_NONE:.*]] = torch.aten.__isnot__ %[[NONE]], %[[NONE]] : !torch.none, !torch.none -> !torch.bool -// CHECK: %[[PRED:.*]] = torch.prim.If %[[NOT_NONE]] -> (!torch.tensor) { -// CHECK: %[[T:.*]] = torch.prim.unchecked_cast %[[OPTIONAL]] : !torch.optional -> !torch.tensor -// CHECK: torch.prim.If.yield %[[T]] : !torch.tensor -// CHECK: } else { -// CHECK: %[[LITERAL:.*]] = torch.tensor.literal(dense<0.000000e+00> : tensor<3x5xf32>) : !torch.tensor -// CHECK: torch.prim.If.yield %[[LITERAL]] : !torch.tensor -// CHECK: } -// CHECK: return %[[PRED:.*]] : !torch.tensor - -func.func @prim.if$refined_type_conflicting(%none: !torch.none) -> !torch.tensor { - %optional = torch.derefine %none: !torch.none to !torch.optional - %pred = torch.aten.__isnot__ %optional, %none : !torch.optional, !torch.none -> !torch.bool - %res = torch.prim.If %pred -> (!torch.tensor) { - %t = torch.prim.unchecked_cast %optional: !torch.optional -> !torch.tensor - torch.prim.If.yield %t: !torch.tensor - } else { - %t_cst = torch.tensor.literal(dense<0.0> : tensor<3x5xf32>) : !torch.tensor - torch.prim.If.yield %t_cst: !torch.tensor - } - return %res: !torch.tensor -} - -// ----- - -// CHECK-LABEL: func.func @prim.loop$region_arg_to_internal( -// CHECK-SAME: %[[ARG_NONE:.*]]: !torch.none) -> !torch.optional { -// CHECK: %[[INT10:.*]] = torch.constant.int 10 -// CHECK: %[[INDV:.*]] = torch.constant.int 0 -// CHECK: %[[TRUE:.*]] = torch.constant.bool true -// CHECK: %[[OPTIONAL:.*]] = torch.derefine %[[ARG_NONE]] : !torch.none to !torch.optional -// CHECK: %[[LOOP_RET:.*]] = torch.prim.Loop %[[INT10]], %[[TRUE]], init(%[[OPTIONAL]]) { -// CHECK: ^bb0(%[[INDV:.*]]: !torch.int, %[[IT:.*]]: !torch.optional): -// CHECK: %[[NONE:.*]] = torch.prim.unchecked_cast %[[IT]] : !torch.optional -> !torch.none -// CHECK: %[[OPTIONAL:.*]] = torch.derefine %[[NONE]] : !torch.none to !torch.optional -// CHECK: %[[COND:.*]] = torch.aten.__isnot__ %[[NONE]], %[[ARG_NONE]] : !torch.none, !torch.none -> !torch.bool -// CHECK: torch.prim.Loop.condition %[[COND]], iter(%[[OPTIONAL]] : !torch.optional) -// CHECK: } : (!torch.int, !torch.bool, !torch.optional) -> !torch.optional -// CHECK: %[[NONE:.*]] = torch.prim.unchecked_cast %[[LOOP_RET:.*]] : !torch.optional -> !torch.none -// CHECK: %[[OPTIONAL:.*]] = torch.derefine %[[NONE]] : !torch.none to !torch.optional -// CHECK: return %[[OPTIONAL]] : !torch.optional - -func.func @prim.loop$region_arg_to_internal(%none: !torch.none) -> !torch.optional { - %int10 = torch.constant.int 10 - %int0 = torch.constant.int 0 - %true = torch.constant.bool true - %optional = torch.derefine %none: !torch.none to !torch.optional - %ret = torch.prim.Loop %int10, %true, init(%optional) { - ^bb0(%arg2: !torch.int, %arg3: !torch.optional): // no predecessors - %cond = torch.aten.__isnot__ %arg3, %none : !torch.optional, !torch.none -> !torch.bool - torch.prim.Loop.condition %cond, iter(%arg3: !torch.optional) - } : (!torch.int, !torch.bool, !torch.optional) -> (!torch.optional) - return %ret: !torch.optional -} - -// ----- - -// CHECK-LABEL: func.func @f -// CHECK: %[[ATEN:.*]] = torch.aten.cos %{{.*}} : !torch.vtensor -> !torch.vtensor<*,f32> -// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[ATEN]] : !torch.vtensor<*,f32> to !torch.vtensor -// CHECK: return %[[CAST]] : !torch.vtensor -func.func @f(%arg0: !torch.vtensor<*,f32>) -> !torch.vtensor { - %cast = torch.tensor_static_info_cast %arg0 : !torch.vtensor<*,f32> to !torch.vtensor - cf.br ^bb1(%cast: 
!torch.vtensor) -^bb1(%arg1: !torch.vtensor): - %1 = torch.aten.cos %arg1 : !torch.vtensor -> !torch.vtensor - return %1 : !torch.vtensor -} - -// ----- - -// CHECK-LABEL: func.func @f -// CHECK: func.func private @callee -// CHECK-NEXT: torch.aten.cos %{{.*}} : !torch.vtensor -> !torch.vtensor<*,f32> -func.func @f() { - builtin.module { - func.func private @callee(%arg0: !torch.vtensor) { - %1 = torch.aten.cos %arg0 : !torch.vtensor -> !torch.vtensor - return - } - func.func @caller(%arg0: !torch.vtensor<*,f32>) { - %cast = torch.tensor_static_info_cast %arg0 : !torch.vtensor<*,f32> to !torch.vtensor - call @callee(%cast) : (!torch.vtensor) -> () - return - } - } - return -} diff --git a/test/Dialect/Torch/refine-types-ops.mlir b/test/Dialect/Torch/refine-types-ops.mlir deleted file mode 100644 index 3c90de228..000000000 --- a/test/Dialect/Torch/refine-types-ops.mlir +++ /dev/null @@ -1,364 +0,0 @@ -// RUN: torch-mlir-opt -torch-refine-types -split-input-file %s | FileCheck %s - -// This file is for tests for individual ops that require a new transfer -// function (i.e. new code called from visitOperation). - -// ----- -// CHECK-LABEL: func.func @aten.arange.start$int64_dtype( -// CHECK-SAME: %[[START:.*]]: !torch.int, -// CHECK-SAME: %[[END:.*]]: !torch.int) -> !torch.vtensor { -// CHECK: %[[NONE:.*]] = torch.constant.none -// CHECK: %[[T:.*]] = torch.aten.arange.start -// CHECK-SAME: %[[START]], %[[END]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]] : -// CHECK-SAME: !torch.int, !torch.int, !torch.none, !torch.none, !torch.none, !torch.none -// CHECK-SAME: -> !torch.vtensor<*,si64> -// CHECK: %[[RET:.*]] = torch.tensor_static_info_cast %[[T]] : !torch.vtensor<*,si64> to !torch.vtensor -// CHECK: return %[[RET]] : !torch.vtensor -func.func @aten.arange.start$int64_dtype(%start: !torch.int, %end: !torch.int) -> !torch.vtensor { - %none = torch.constant.none - %ret = torch.aten.arange.start %start, %end, %none, %none, %none, %none: !torch.int, !torch.int, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor - return %ret : !torch.vtensor -} - -// ----- -// CHECK-LABEL: func.func @aten.arange.start$float32_dtype( -// CHECK-SAME: %[[START:.*]]: !torch.float, -// CHECK-SAME: %[[END:.*]]: !torch.int) -> !torch.vtensor { -// CHECK: %[[NONE:.*]] = torch.constant.none -// CHECK: %[[T:.*]] = torch.aten.arange.start -// CHECK-SAME: %[[START]], %[[END]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]] : -// CHECK-SAME: !torch.float, !torch.int, !torch.none, !torch.none, !torch.none, !torch.none -// CHECK-SAME: -> !torch.vtensor<*,f32> -// CHECK: %[[RET:.*]] = torch.tensor_static_info_cast %[[T]] : !torch.vtensor<*,f32> to !torch.vtensor -// CHECK: return %[[RET]] : !torch.vtensor -func.func @aten.arange.start$float32_dtype(%start: !torch.float, %end: !torch.int) -> !torch.vtensor { - %none = torch.constant.none - %ret = torch.aten.arange.start %start, %end, %none, %none, %none, %none: !torch.float, !torch.int, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor - return %ret : !torch.vtensor -} - -// ----- -// CHECK-LABEL: func.func @aten.arange.start$specified_dtype( -// CHECK-SAME: %[[END:.*]]: !torch.int) -> !torch.vtensor { -// CHECK: %[[CST6:.*]] = torch.constant.int 6 -// CHECK: %[[NONE:.*]] = torch.constant.none -// CHECK: %[[T:.*]] = torch.aten.arange -// CHECK-SAME: %[[END]], %[[CST6]], %[[NONE]], %[[NONE]], %[[NONE]] : -// CHECK-SAME: !torch.int, !torch.int, !torch.none, !torch.none, !torch.none -// CHECK-SAME: -> !torch.vtensor<*,f32> -// CHECK: %[[RET:.*]] = 
torch.tensor_static_info_cast %[[T]] : !torch.vtensor<*,f32> to !torch.vtensor -// CHECK: return %[[RET]] : !torch.vtensor -func.func @aten.arange.start$specified_dtype(%end: !torch.int) -> !torch.vtensor { - %int6 = torch.constant.int 6 - %none = torch.constant.none - %ret = torch.aten.arange %end, %int6, %none, %none, %none: !torch.int, !torch.int, !torch.none, !torch.none, !torch.none -> !torch.vtensor - return %ret : !torch.vtensor -} - -// ----- -// CHECK-LABEL: func.func @torch.aten.linear( -// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,3],f32>, -// CHECK-SAME: %[[ARG1:.*]]: !torch.vtensor<[5,3],f32>, -// CHECK-SAME: %[[ARG2:.*]]: !torch.vtensor<[5],f32>) -> !torch.vtensor { -// CHECK: %[[LINEAR:.*]] = torch.aten.linear %[[ARG0]], %[[ARG1]], %[[ARG2]] : !torch.vtensor<[?,3],f32>, !torch.vtensor<[5,3],f32>, !torch.vtensor<[5],f32> -> !torch.vtensor<*,f32> -// CHECK: %[[RESULT:.*]] = torch.tensor_static_info_cast %[[LINEAR]] : !torch.vtensor<*,f32> to !torch.vtensor -// CHECK: return %[[RESULT]] : !torch.vtensor -func.func @torch.aten.linear(%arg0: !torch.vtensor<[?,3],f32>, %arg1: !torch.vtensor<[5,3],f32>, %arg2: !torch.vtensor<[5],f32>) -> !torch.vtensor { - %1 = torch.aten.linear %arg0, %arg1, %arg2 : !torch.vtensor<[?,3],f32>, !torch.vtensor<[5,3],f32>, !torch.vtensor<[5],f32> -> !torch.vtensor - return %1 : !torch.vtensor -} - -// ----- -// CHECK-LABEL: func.func @aten.sum.dim_IntList( -// CHECK-SAME: %[[T:.*]]: !torch.vtensor<*,si64>) -> !torch.vtensor { -// CHECK: %[[FALSE:.*]] = torch.constant.bool false -// CHECK: %[[NONE:.*]] = torch.constant.none -// CHECK: %[[INT0:.*]] = torch.constant.int 0 -// CHECK: %[[INT_NEG1:.*]] = torch.constant.int -1 -// CHECK: %[[DIMLIST:.*]] = torch.prim.ListConstruct %[[INT0]], %[[INT_NEG1]] -// CHECK-SAME: : (!torch.int, !torch.int) -> !torch.list -// CHECK: %[[RET:.*]] = torch.aten.sum.dim_IntList %[[T]], %[[DIMLIST]], %[[FALSE]], %[[NONE]] -// CHECK-SAME: : !torch.vtensor<*,si64>, !torch.list, !torch.bool, !torch.none -// CHECK-SAME: -> !torch.vtensor<*,si64> -// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[RET]] : !torch.vtensor<*,si64> to !torch.vtensor -// CHECK: return %[[CAST]] : !torch.vtensor -func.func @aten.sum.dim_IntList(%t: !torch.vtensor<*,si64>) -> !torch.vtensor { - %false = torch.constant.bool false - %none = torch.constant.none - %int0 = torch.constant.int 0 - %int-1 = torch.constant.int -1 - %dimList = torch.prim.ListConstruct %int0, %int-1 : (!torch.int, !torch.int) -> !torch.list - %ret = torch.aten.sum.dim_IntList %t, %dimList, %false, %none : !torch.vtensor<*,si64>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor - return %ret : !torch.vtensor -} - -// ----- -// CHECK-LABEL: func.func @aten.any.dim( -// CHECK-SAME: %[[T:.*]]: !torch.vtensor<*,i1>) -> !torch.vtensor { -// CHECK: %[[FALSE:.*]] = torch.constant.bool false -// CHECK: %[[INT_NEG1:.*]] = torch.constant.int -1 -// CHECK: %[[RET:.*]] = torch.aten.any.dim %[[T]], %[[INT_NEG1]], %[[FALSE]] : !torch.vtensor<*,i1>, !torch.int, !torch.bool -> !torch.vtensor<*,i1> -// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[RET]] : !torch.vtensor<*,i1> to !torch.vtensor -// CHECK: return %[[CAST]] : !torch.vtensor -func.func @aten.any.dim(%t: !torch.vtensor<*,i1>) -> !torch.vtensor { - %false = torch.constant.bool false - %int-1 = torch.constant.int -1 - %ret = torch.aten.any.dim %t, %int-1, %false : !torch.vtensor<*,i1>, !torch.int, !torch.bool -> !torch.vtensor - return %ret : !torch.vtensor -} - -// ----- -// CHECK-LABEL: func.func @aten.any( -// 
CHECK-SAME: %[[T:.*]]: !torch.vtensor<*,i1>) -> !torch.vtensor {
-// CHECK: %[[RET:.*]] = torch.aten.any %[[T]] : !torch.vtensor<*,i1> -> !torch.vtensor<*,i1>
-// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[RET]] : !torch.vtensor<*,i1> to !torch.vtensor
-// CHECK: return %[[CAST]] : !torch.vtensor
-func.func @aten.any(%t: !torch.vtensor<*,i1>) -> !torch.vtensor {
-  %ret = torch.aten.any %t: !torch.vtensor<*,i1> -> !torch.vtensor
-  return %ret : !torch.vtensor
-}
-
-// -----
-// CHECK-LABEL: func.func @torch.aten.zeros(
-// CHECK-SAME: %[[DIM0:.*]]: !torch.int) -> !torch.tensor {
-// CHECK: %[[NONE:.*]] = torch.constant.none
-// CHECK: %[[INT2:.*]] = torch.constant.int 2
-// CHECK: %[[SIZES:.*]] = torch.prim.ListConstruct %[[DIM0]], %[[INT2]] : (!torch.int, !torch.int) -> !torch.list<int>
-// CHECK: %[[ZEROS:.*]] = torch.aten.zeros %[[SIZES]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]] : !torch.list<int>, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.tensor<*,f32>
-// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[ZEROS]] : !torch.tensor<*,f32> to !torch.tensor
-// CHECK: return %[[CAST]] : !torch.tensor
-func.func @torch.aten.zeros(%dim0: !torch.int) -> !torch.tensor {
-  %none = torch.constant.none
-  %int2 = torch.constant.int 2
-  %sizesList = torch.prim.ListConstruct %dim0, %int2 : (!torch.int, !torch.int) -> !torch.list<int>
-  %ret = torch.aten.zeros %sizesList, %none, %none, %none, %none : !torch.list<int>, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.tensor
-  return %ret : !torch.tensor
-}
-
-// -----
-// CHECK-LABEL: func.func @torch.aten.type_as(
-// CHECK-SAME: %[[INPUT:.*]]: !torch.tensor<[?],si64>,
-// CHECK-SAME: %[[OTHER:.*]]: !torch.tensor<[?,2],f32>) -> !torch.tensor {
-// CHECK: %[[RET:.*]] = torch.aten.type_as %[[INPUT]], %[[OTHER]] : !torch.tensor<[?],si64>, !torch.tensor<[?,2],f32> -> !torch.tensor<*,f32>
-// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[RET]] : !torch.tensor<*,f32> to !torch.tensor
-// CHECK: return %[[CAST]] : !torch.tensor
-func.func @torch.aten.type_as(%self: !torch.tensor<[?], si64>, %other: !torch.tensor<[?,2],f32>) -> !torch.tensor {
-  %ret = torch.aten.type_as %self, %other : !torch.tensor<[?], si64>, !torch.tensor<[?,2],f32> -> !torch.tensor
-  return %ret: !torch.tensor
-}
-
-// -----
-// CHECK-LABEL: func.func @torch.aten.cat(
-// CHECK-SAME: %[[T1:.*]]: !torch.tensor<[?,1,4],f32>,
-// CHECK-SAME: %[[T2:.*]]: !torch.tensor<[2,3,4],f32>) -> !torch.tensor {
-// CHECK: %[[INT1:.*]] = torch.constant.int 1
-// CHECK: %[[TENSORS:.*]] = torch.prim.ListConstruct %[[T1]], %[[T2]] : (!torch.tensor<[?,1,4],f32>, !torch.tensor<[2,3,4],f32>) -> !torch.list<tensor>
-// CHECK: %[[RET:.*]] = torch.aten.cat %[[TENSORS]], %[[INT1]] : !torch.list<tensor>, !torch.int -> !torch.tensor<*,f32>
-// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[RET]] : !torch.tensor<*,f32> to !torch.tensor
-// CHECK: return %[[CAST]] : !torch.tensor
-func.func @torch.aten.cat(%t0: !torch.tensor<[?,1,4], f32>, %t1: !torch.tensor<[2,3,4], f32>) -> !torch.tensor {
-  %int1 = torch.constant.int 1
-  %tensorList = torch.prim.ListConstruct %t0, %t1: (!torch.tensor<[?,1,4], f32>, !torch.tensor<[2,3,4], f32>) -> !torch.list<tensor>
-  %ret = torch.aten.cat %tensorList, %int1 : !torch.list<tensor>, !torch.int -> !torch.tensor
-  return %ret : !torch.tensor
-}
-
-// -----
-// CHECK-LABEL: func.func @torch.aten.cat$promote_type(
-// CHECK-SAME: %[[T1:.*]]: !torch.tensor<[2,1,4],i1>,
-// CHECK-SAME: %[[T2:.*]]: !torch.tensor<[2,3,4],si64>) -> !torch.tensor {
-// CHECK: %[[INT1:.*]] = torch.constant.int 1
-// CHECK: %[[TENSORS:.*]] = torch.prim.ListConstruct %[[T1]], %[[T2]] : (!torch.tensor<[2,1,4],i1>, !torch.tensor<[2,3,4],si64>) -> !torch.list<tensor>
-// CHECK: %[[RET:.*]] = torch.aten.cat %[[TENSORS]], %[[INT1]] : !torch.list<tensor>, !torch.int -> !torch.tensor<*,si64>
-// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[RET]] : !torch.tensor<*,si64> to !torch.tensor
-// CHECK: return %[[CAST]] : !torch.tensor
-func.func @torch.aten.cat$promote_type(%t0: !torch.tensor<[2,1,4], i1>, %t1: !torch.tensor<[2,3,4], si64>) -> !torch.tensor {
-  %int1 = torch.constant.int 1
-  %tensorList = torch.prim.ListConstruct %t0, %t1: (!torch.tensor<[2,1,4], i1>, !torch.tensor<[2,3,4], si64>) -> !torch.list<tensor>
-  %ret = torch.aten.cat %tensorList, %int1 : !torch.list<tensor>, !torch.int -> !torch.tensor
-  return %ret : !torch.tensor
-}
-
-// -----
-// CHECK-LABEL: func.func @torch.aten._shape_as_tensor(
-// CHECK-SAME: %[[INPUT:.*]]: !torch.tensor<[?,1,4],f32>) -> !torch.tensor {
-// CHECK: %[[RET:.*]] = torch.aten._shape_as_tensor %[[INPUT]] : !torch.tensor<[?,1,4],f32> -> !torch.tensor<*,si64>
-// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[RET]] : !torch.tensor<*,si64> to !torch.tensor
-// CHECK: return %[[CAST]] : !torch.tensor
-func.func @torch.aten._shape_as_tensor(%input: !torch.tensor<[?,1,4], f32>) -> !torch.tensor {
-  %ret= torch.aten._shape_as_tensor %input : !torch.tensor<[?,1,4], f32> -> !torch.tensor
-  return %ret : !torch.tensor
-}
-
-// -----
-// CHECK-LABEL: func.func @torch.aten._shape_as_tensor$unknown_input_shape(
-// CHECK-SAME: %[[INPUT:.*]]: !torch.tensor) -> !torch.tensor {
-// CHECK: %[[RET:.*]] = torch.aten._shape_as_tensor %[[INPUT]] : !torch.tensor -> !torch.tensor<*,si64>
-// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[RET]] : !torch.tensor<*,si64> to !torch.tensor
-// CHECK: return %[[CAST]] : !torch.tensor
-func.func @torch.aten._shape_as_tensor$unknown_input_shape(%input: !torch.tensor) -> !torch.tensor {
-  %ret= torch.aten._shape_as_tensor %input : !torch.tensor -> !torch.tensor
-  return %ret : !torch.tensor
-}
-
-// -----
-// CHECK-LABEL: func.func @torch.aten.embedding(
-// CHECK-SAME: %[[INPUT:.*]]: !torch.tensor<[104,512],f32>,
-// CHECK-SAME: %[[INDEXES:.*]]: !torch.tensor<[2,3],si64>) -> !torch.tensor {
-// CHECK: %[[FALSE:.*]] = torch.constant.bool false
-// CHECK: %[[PADDING_IDX:.*]] = torch.constant.int 1
-// CHECK: %[[RET:.*]] = torch.aten.embedding %[[INPUT]], %[[INDEXES]], %[[PADDING_IDX]], %[[FALSE]], %[[FALSE]] : !torch.tensor<[104,512],f32>, !torch.tensor<[2,3],si64>, !torch.int, !torch.bool, !torch.bool -> !torch.tensor<*,f32>
-// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[RET]] : !torch.tensor<*,f32> to !torch.tensor
-// CHECK: return %[[CAST]] : !torch.tensor
-func.func @torch.aten.embedding(%weight: !torch.tensor<[104,512],f32>, %indices: !torch.tensor<[2,3], si64>) -> !torch.tensor {
-  %false = torch.constant.bool false
-  %int1 = torch.constant.int 1
-  %ret = torch.aten.embedding %weight, %indices, %int1, %false, %false : !torch.tensor<[104,512],f32>, !torch.tensor<[2,3], si64>, !torch.int, !torch.bool, !torch.bool -> !torch.tensor
-  return %ret: !torch.tensor
-}
-
-// -----
-// CHECK-LABEL: func.func @torch.aten.tensor.float(
-// CHECK-SAME: %[[t:.*]]: !torch.float) -> !torch.tensor {
-// CHECK: %[[NONE:.*]] = torch.constant.none
-// CHECK: %[[FALSE:.*]] = torch.constant.bool false
-// CHECK: %[[RET:.*]] = torch.aten.tensor.float %[[t]], %[[NONE]], %[[NONE]], %[[FALSE]] : !torch.float, !torch.none, !torch.none, !torch.bool -> !torch.tensor<*,f32>
-// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[RET]] : !torch.tensor<*,f32> to !torch.tensor
-// CHECK: return %[[CAST]] : !torch.tensor
-func.func @torch.aten.tensor.float(%t: !torch.float) -> !torch.tensor {
-  %none = torch.constant.none
-  %false = torch.constant.bool false
-  %ret = torch.aten.tensor.float %t, %none, %none, %false : !torch.float, !torch.none, !torch.none, !torch.bool -> !torch.tensor
-  return %ret : !torch.tensor
-}
-
-// -----
-// CHECK-LABEL: func.func @torch.aten.tensor.float$specified_dtype(
-// CHECK-SAME: %[[t:.*]]: !torch.float) -> !torch.tensor {
-// CHECK: %[[NONE:.*]] = torch.constant.none
-// CHECK: %[[CST11:.*]] = torch.constant.int 11
-// CHECK: %[[FALSE:.*]] = torch.constant.bool false
-// CHECK: %[[RET:.*]] = torch.aten.tensor.float %[[t]], %[[CST11]], %[[NONE]], %[[FALSE]] : !torch.float, !torch.int, !torch.none, !torch.bool -> !torch.tensor<*,i1>
-// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[RET]] : !torch.tensor<*,i1> to !torch.tensor
-// CHECK: return %[[CAST]] : !torch.tensor
-func.func @torch.aten.tensor.float$specified_dtype(%t: !torch.float) -> !torch.tensor {
-  %none = torch.constant.none
-  %int11 = torch.constant.int 11
-  %false = torch.constant.bool false
-  %ret = torch.aten.tensor.float %t, %int11, %none, %false : !torch.float, !torch.int, !torch.none, !torch.bool -> !torch.tensor
-  return %ret : !torch.tensor
-}
-
-// -----
-// CHECK-LABEL: func.func @torch.aten.softmax.int(
-// CHECK-SAME: %[[T:.*]]: !torch.tensor<[2,3],f32>,
-// CHECK-SAME: %[[DIM:.*]]: !torch.int) -> !torch.tensor {
-// CHECK: %[[DTYPE:.*]] = torch.constant.none
-// CHECK: %[[SOFTMAX:.*]] = torch.aten.softmax.int %[[T]], %[[DIM]], %[[DTYPE]] : !torch.tensor<[2,3],f32>, !torch.int, !torch.none -> !torch.tensor<*,f32>
-// CHECK: %[[RET:.*]] = torch.tensor_static_info_cast %[[SOFTMAX]] : !torch.tensor<*,f32> to !torch.tensor
-// CHECK: return %[[RET]] : !torch.tensor
-func.func @torch.aten.softmax.int(%t: !torch.tensor<[2,3],f32>, %dim: !torch.int) -> !torch.tensor {
-  %none = torch.constant.none
-  %ret = torch.aten.softmax.int %t, %dim, %none : !torch.tensor<[2,3],f32>, !torch.int, !torch.none -> !torch.tensor
-  return %ret : !torch.tensor
-}
-
-// -----
-// CHECK-LABEL: func.func @torch.aten.softmax.int$specified_dtype(
-// CHECK-SAME: %[[T:.*]]: !torch.tensor<[2,3],f32>,
-// CHECK-SAME: %[[DIM:.*]]: !torch.int) -> !torch.tensor {
-// CHECK: %[[DTYPE:.*]] = torch.constant.int 4
-// CHECK: %[[SOFTMAX:.*]] = torch.aten.softmax.int %[[T]], %[[DIM]], %[[DTYPE]] : !torch.tensor<[2,3],f32>, !torch.int, !torch.int -> !torch.tensor<*,si64>
-// CHECK: %[[RET:.*]] = torch.tensor_static_info_cast %[[SOFTMAX]] : !torch.tensor<*,si64> to !torch.tensor
-// CHECK: return %[[RET]] : !torch.tensor
-func.func @torch.aten.softmax.int$specified_dtype(%t: !torch.tensor<[2,3],f32>, %dim: !torch.int) -> !torch.tensor {
-  %int4 = torch.constant.int 4
-  %ret = torch.aten.softmax.int %t, %dim, %int4: !torch.tensor<[2,3],f32>, !torch.int, !torch.int -> !torch.tensor
-  return %ret : !torch.tensor
-}
-
-// -----
-// CHECK-LABEL: func.func @torch.aten.Matmul.Broadcast.Matrix(
-// CHECK-SAME: %[[LHS:.*]]: !torch.vtensor<*,f32>,
-// CHECK-SAME: %[[RHS:.*]]: !torch.vtensor<[?,?,?],f32>) -> !torch.tensor {
-// CHECK: %[[MUL:.*]] = torch.aten.matmul %[[LHS]], %[[RHS]] : !torch.vtensor<*,f32>, !torch.vtensor<[?,?,?],f32> -> !torch.tensor<*,f32>
-// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[MUL]] : !torch.tensor<*,f32> to !torch.tensor
-// CHECK: return %[[CAST]] : !torch.tensor
-func.func @torch.aten.Matmul.Broadcast.Matrix(%arg0: !torch.vtensor<*,f32>, %arg1: !torch.vtensor<[?,?,?],f32>) -> !torch.tensor {
-  %0 = torch.aten.matmul %arg0, %arg1 : !torch.vtensor<*,f32>, !torch.vtensor<[?,?,?],f32> -> !torch.tensor
-  return %0 : !torch.tensor
-}
-
-// -----
-// CHECK-LABEL: func.func @torch.aten.Matmul.Broadcast.Vector(
-// CHECK-SAME: %[[LHS:.*]]: !torch.vtensor<*,f32>,
-// CHECK-SAME: %[[RHS:.*]]: !torch.vtensor<*,f32>) -> !torch.tensor {
-// CHECK: %[[MUL:.*]] = torch.aten.matmul %[[LHS]], %[[RHS]] : !torch.vtensor<*,f32>, !torch.vtensor<*,f32> -> !torch.tensor<*,f32>
-// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[MUL]] : !torch.tensor<*,f32> to !torch.tensor
-// CHECK: return %[[CAST]] : !torch.tensor
-func.func @torch.aten.Matmul.Broadcast.Vector(%arg0: !torch.vtensor<*,f32>, %arg1: !torch.vtensor<*,f32>) -> !torch.tensor {
-  %0 = torch.aten.matmul %arg0, %arg1 : !torch.vtensor<*,f32>, !torch.vtensor<*,f32> -> !torch.tensor
-  return %0 : !torch.tensor
-}
-
-// -----
-// CHECK-LABEL: func.func @torch.aten.to.dtype(
-// CHECK-SAME: %[[ARG:.*]]: !torch.tensor<[?,?],f32>) -> !torch.tensor
-// CHECK: %[[TODTYPE:.*]] = torch.aten.to.dtype
-// CHECK-SAME: %[[ARG]], %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} :
-// CHECK-SAME: !torch.tensor<[?,?],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none
-// CHECK-SAME: -> !torch.tensor<*,si64>
-// CHECK-NEXT: %[[RES:.*]] = torch.tensor_static_info_cast %[[TODTYPE]] : !torch.tensor<*,si64> to !torch.tensor
-// CHECK-NEXT: return %[[RES]] : !torch.tensor
-func.func @torch.aten.to.dtype(%arg0: !torch.tensor<[?,?],f32>) -> !torch.tensor{
-  %none = torch.constant.none
-  %false = torch.constant.bool false
-  %int4 = torch.constant.int 4
-  %0 = torch.aten.to.dtype %arg0, %int4, %false, %false, %none : !torch.tensor<[?,?],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.tensor
-  return %0 : !torch.tensor
-}
-
-// -----
-// CHECK-LABEL: func.func @torch.prim.NumToTensor.Scalar(
-// CHECK-SAME: %[[SELF:.*]]: !torch.int) -> !torch.tensor {
-// CHECK: %[[NTT:.*]] = torch.prim.NumToTensor.Scalar %[[SELF]] : !torch.int -> !torch.tensor<*,si64>
-// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[NTT]] : !torch.tensor<*,si64> to !torch.tensor
-// CHECK: return %[[CAST]] : !torch.tensor
-func.func @torch.prim.NumToTensor.Scalar(%arg0: !torch.int) -> !torch.tensor {
-  %0 = torch.prim.NumToTensor.Scalar %arg0: !torch.int -> !torch.tensor
-  return %0: !torch.tensor
-}
-
-// -----
-// CHECK-LABEL: func.func @torch.aten.tensor(
-// CHECK-SAME: %[[DATA:.*]]: !torch.list<list<float>>) -> !torch.tensor {
-// CHECK: %[[NONE:.*]] = torch.constant.none
-// CHECK: %[[FALSE:.*]] = torch.constant.bool false
-// CHECK: %[[RET:.*]] = torch.aten.tensor %[[DATA]], %[[NONE]], %[[NONE]], %[[FALSE]]
-// CHECK-SAME: : !torch.list<list<float>>, !torch.none, !torch.none, !torch.bool
-// CHECK-SAME: -> !torch.tensor<*,f32>
-// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[RET]] : !torch.tensor<*,f32> to !torch.tensor
-// CHECK: return %[[CAST]] : !torch.tensor
-func.func @torch.aten.tensor(%t: !torch.list<list<float>>) -> !torch.tensor {
-  %none = torch.constant.none
-  %false = torch.constant.bool false
-  %ret = torch.aten.tensor %t, %none, %none, %false : !torch.list<list<float>>, !torch.none, !torch.none, !torch.bool -> !torch.tensor
-  return %ret : !torch.tensor
-}
-
-// -----
-// CHECK-LABEL: func.func @torch.aten.tensor$specified_dtype(
-// CHECK-SAME: %[[DATA:.*]]: !torch.list<list<float>>) -> !torch.tensor {
-// CHECK: %[[NONE:.*]] = torch.constant.none
-// CHECK: %[[INT4:.*]] = torch.constant.int 4
-// CHECK: %[[FALSE:.*]] = torch.constant.bool false
-// CHECK: %[[RET:.*]] = torch.aten.tensor %[[DATA]], %[[INT4]], %[[NONE]], %[[FALSE]] : !torch.list<list<float>>, !torch.int, !torch.none, !torch.bool -> !torch.tensor<*,si64>
-// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[RET]] : !torch.tensor<*,si64> to !torch.tensor
-// CHECK: return %[[CAST]] : !torch.tensor
-func.func @torch.aten.tensor$specified_dtype(%t: !torch.list<list<float>>) -> !torch.tensor {
-  %none = torch.constant.none
-  %int4 = torch.constant.int 4
-  %false = torch.constant.bool false
-  %ret = torch.aten.tensor %t, %int4, %none, %false : !torch.list<list<float>>, !torch.int, !torch.none, !torch.bool -> !torch.tensor
-  return %ret : !torch.tensor
-}
diff --git a/test/Dialect/Torch/refine-types.mlir b/test/Dialect/Torch/refine-types.mlir
deleted file mode 100644
index 50d96b08e..000000000
--- a/test/Dialect/Torch/refine-types.mlir
+++ /dev/null
@@ -1,238 +0,0 @@
-// RUN: torch-mlir-opt -torch-refine-types -split-input-file %s | FileCheck %s
-
-// This file tests the structural logic of the pass. This is for testing logic
-// that does not scale with the number of ops supported, such as the core
-// propagation logic, rewriting, etc.
-// Code for testing transfer functions for new ops (which is most changes)
-// should go in refine-types-ops.mlir.
-
-// -----
-// CHECK-LABEL: func.func @basic(
-// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<*,f32>) -> !torch.vtensor {
-// CHECK: %[[COS:.*]] = torch.aten.cos %[[ARG0]] : !torch.vtensor<*,f32> -> !torch.vtensor<*,f32>
-// CHECK: %[[RESULT:.*]] = torch.tensor_static_info_cast %[[COS]] : !torch.vtensor<*,f32> to !torch.vtensor
-// CHECK: return %[[RESULT]] : !torch.vtensor
-func.func @basic(%arg0: !torch.vtensor<*,f32>) -> !torch.vtensor {
-  %1 = torch.aten.cos %arg0 : !torch.vtensor<*,f32> -> !torch.vtensor
-  return %1 : !torch.vtensor
-}
-
-// -----
-// CHECK-LABEL: func.func @keep_existing_shape_information(
-// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<*,f32>) -> !torch.vtensor<[2],f32> {
-// CHECK: %[[COS:.*]] = torch.aten.cos %[[ARG0]] : !torch.vtensor<*,f32> -> !torch.vtensor<[2],f32>
-// CHECK: return %[[COS]] : !torch.vtensor<[2],f32>
-func.func @keep_existing_shape_information(%arg0: !torch.vtensor<*,f32>) -> !torch.vtensor<[2],f32> {
-  %1 = torch.aten.cos %arg0 : !torch.vtensor<*,f32> -> !torch.vtensor<[2], f32>
-  return %1 : !torch.vtensor<[2],f32>
-}
-
-// -----
-// CHECK-LABEL: func.func @propagate_through_multiple_ops(
-// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<*,f32>) -> !torch.vtensor {
-// CHECK: %[[COS0:.*]] = torch.aten.cos %[[ARG0]] : !torch.vtensor<*,f32> -> !torch.vtensor<*,f32>
-// CHECK: %[[COS1:.*]] = torch.aten.cos %[[COS0]] : !torch.vtensor<*,f32> -> !torch.vtensor<*,f32>
-// CHECK: %[[COS2:.*]] = torch.aten.cos %[[COS1]] : !torch.vtensor<*,f32> -> !torch.vtensor<*,f32>
-// CHECK: %[[COS3:.*]] = torch.tensor_static_info_cast %[[COS2]] : !torch.vtensor<*,f32> to !torch.vtensor
-// CHECK: return %[[COS3]] : !torch.vtensor
-func.func @propagate_through_multiple_ops(%arg0: !torch.vtensor<*,f32>) -> !torch.vtensor {
-  %1 = torch.aten.cos %arg0 : !torch.vtensor<*,f32> -> !torch.vtensor
-  %2 = torch.aten.cos %1 : !torch.vtensor -> !torch.vtensor
-  %3 = torch.aten.cos %2 : !torch.vtensor -> !torch.vtensor
-  return %3 : !torch.vtensor
-}
-
-// -----
-// Check rewriting logic in case of mixes of users that do/don't allow type
-// refinement.
-// CHECK-LABEL: func.func @mixed_allowing_not_allowing_type_refinement(
-// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<*,f32>) -> (!torch.vtensor, !torch.vtensor) {
-// CHECK: %[[COS0:.*]] = torch.aten.cos %[[ARG0]] : !torch.vtensor<*,f32> -> !torch.vtensor<*,f32>
-// CHECK: %[[ERASED:.*]] = torch.tensor_static_info_cast %[[COS0]] : !torch.vtensor<*,f32> to !torch.vtensor
-// CHECK: %[[COS1:.*]] = torch.aten.cos %[[COS0]] : !torch.vtensor<*,f32> -> !torch.vtensor<*,f32>
-// CHECK: return %[[ERASED]], %[[ERASED]] : !torch.vtensor, !torch.vtensor
-func.func @mixed_allowing_not_allowing_type_refinement(%arg0: !torch.vtensor<*,f32>) -> (!torch.vtensor, !torch.vtensor) {
-  %1 = torch.aten.cos %arg0 : !torch.vtensor<*,f32> -> !torch.vtensor
-  %3 = torch.aten.cos %1 : !torch.vtensor -> !torch.vtensor
-  return %1, %1 : !torch.vtensor, !torch.vtensor
-}
-
-// -----
-
-// CHECK-LABEL: func.func @torch.overwrite.tensor.contents$dynamic_overwrites_static(
-// CHECK-SAME: %[[STATIC:.*]]: !torch.vtensor<[2],f32>,
-// CHECK-SAME: %[[DYNAMIC:.*]]: !torch.vtensor<[?],f32>) -> !torch.vtensor<[2],f32> {
-// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[DYNAMIC_COPY:.*]] : !torch.vtensor<[?],f32> to !torch.vtensor<*,f32>
-// CHECK: %[[CAST2:.*]] = torch.tensor_static_info_cast %[[CAST:.*]] : !torch.vtensor<*,f32> to !torch.vtensor<*,f32>
-// CHECK: torch.overwrite.tensor.contents %[[CAST2]] overwrites %[[STATIC_COPY:.*]] : !torch.vtensor<*,f32>, !torch.tensor<*,f32>
-func.func @torch.overwrite.tensor.contents$dynamic_overwrites_static(%static: !torch.vtensor<[2],f32>, %dynamic: !torch.vtensor<[?],f32>) -> !torch.vtensor<[2],f32> {
-  %static_no_type = torch.tensor_static_info_cast %static : !torch.vtensor<[2],f32> to !torch.vtensor
-  %static_copy = torch.copy.to_tensor %static_no_type : !torch.tensor
-  %dynamic_no_type = torch.tensor_static_info_cast %dynamic : !torch.vtensor<[?],f32> to !torch.vtensor
-  torch.overwrite.tensor.contents %dynamic_no_type overwrites %static_copy : !torch.vtensor, !torch.tensor
-  %static_value_copy = torch.copy.to_vtensor %static_copy : !torch.vtensor
-  %result = torch.tensor_static_info_cast %static_value_copy : !torch.vtensor to !torch.vtensor<[2],f32>
-  return %result : !torch.vtensor<[2],f32>
-}
-
-// -----
-// CHECK-LABEL: func.func @torch.overwrite.tensor.contents$static_overwrites_dynamic(
-// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[2],f32>,
-// CHECK-SAME: %[[ARG1:.*]]: !torch.vtensor<[?],f32>) -> !torch.vtensor<[?],f32> {
-// CHECK: %[[ARG0_ERASED:.*]] = torch.tensor_static_info_cast %[[ARG0]] : !torch.vtensor<[2],f32> to !torch.vtensor<*,f32>
-// CHECK: %[[ARG1_ERASED:.*]] = torch.tensor_static_info_cast %[[ARG1]] : !torch.vtensor<[?],f32> to !torch.vtensor<*,f32>
-// CHECK: %[[MUTABLE_COPY:.*]] = torch.copy.to_tensor %[[ARG1_ERASED]] : !torch.tensor<*,f32>
-// CHECK: torch.overwrite.tensor.contents %[[ARG0_ERASED]] overwrites %[[MUTABLE_COPY]] : !torch.vtensor<*,f32>, !torch.tensor<*,f32>
-func.func @torch.overwrite.tensor.contents$static_overwrites_dynamic(%static: !torch.vtensor<[2],f32>, %dynamic: !torch.vtensor<[?],f32>) -> !torch.vtensor<[?],f32> {
-  %static_no_type = torch.tensor_static_info_cast %static : !torch.vtensor<[2],f32> to !torch.vtensor
-  %dynamic_no_type = torch.tensor_static_info_cast %dynamic : !torch.vtensor<[?],f32> to !torch.vtensor
-  %dynamic_copy = torch.copy.to_tensor %dynamic_no_type : !torch.tensor
-  torch.overwrite.tensor.contents %static_no_type overwrites %dynamic_copy : !torch.vtensor, !torch.tensor
-  %dynamic_value_copy = torch.copy.to_vtensor %dynamic_copy : !torch.vtensor
-  %result = torch.tensor_static_info_cast %dynamic_value_copy : !torch.vtensor to !torch.vtensor<[?],f32>
-  return %result : !torch.vtensor<[?],f32>
-}
-
-// -----
-// CHECK-LABEL: func.func @bf16_result_type(
-// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<*,bf16>) -> !torch.vtensor<[2],bf16> {
-// CHECK: %[[SQRT:.*]] = torch.aten.sqrt %[[ARG0]] : !torch.vtensor<*,bf16> -> !torch.vtensor<[2],bf16>
-// CHECK: return %[[SQRT]] : !torch.vtensor<[2],bf16>
-func.func @bf16_result_type(%arg0: !torch.vtensor<*,bf16>) -> !torch.vtensor<[2],bf16> {
-  %1 = torch.aten.sqrt %arg0 : !torch.vtensor<*,bf16> -> !torch.vtensor<[2], bf16>
-  return %1 : !torch.vtensor<[2],bf16>
-}
-
-// -----
-// CHECK-LABEL: func.func @propagate_scalar_type(
-// CHECK-SAME: %[[INT:.*]]: !torch.int) -> !torch.number {
-// CHECK: %[[NUM:.*]] = torch.derefine %[[INT]] : !torch.int to !torch.number
-// CHECK: %[[ABS:.*]] = torch.prim.abs.Scalar %[[INT]] : !torch.int -> !torch.int
-// CHECK: %[[RET:.*]] = torch.derefine %[[ABS]] : !torch.int to !torch.number
-// CHECK: return %[[RET]] : !torch.number
-func.func @propagate_scalar_type(%arg0: !torch.int) -> !torch.number {
-  %num = torch.derefine %arg0 : !torch.int to !torch.number
-  %1 = torch.prim.abs.Scalar %num: !torch.number -> !torch.number
-  return %1 : !torch.number
-}
-
-// -----
-// CHECK-LABEL: func.func @prim.dtype(
-// CHECK-SAME: %[[arg:.*]]: !torch.vtensor<*,bf16>) -> !torch.vtensor {
-
-// CHECK: %[[zero:.*]] = torch.constant.int 0
-// CHECK: %[[false:.*]] = torch.constant.bool false
-
-// CHECK: %[[neg:.*]] = torch.aten.neg %[[arg]] : !torch.vtensor<*,bf16> -> !torch.vtensor<*,bf16>
-// CHECK: %[[dtype0:.*]] = torch.prim.dtype %[[neg]] : !torch.vtensor<*,bf16> -> !torch.int
-// CHECK: %[[device0:.*]] = torch.prim.device %[[neg]] : !torch.vtensor<*,bf16> -> !torch.Device
-// CHECK: %[[tensor:.*]] = torch.aten.tensor.int %[[zero]], %[[dtype0]], %[[device0]], %[[false]] : !torch.int, !torch.int, !torch.Device, !torch.bool -> !torch.vtensor<*,bf16>
-
-// CHECK: %[[dtype1:.*]] = torch.prim.dtype %[[tensor]] : !torch.vtensor<*,bf16> -> !torch.int
-// CHECK: %[[device1:.*]] = torch.prim.device %[[tensor]] : !torch.vtensor<*,bf16> -> !torch.Device
-// CHECK: %[[result:.*]] = torch.aten.tensor.int %[[zero]], %[[dtype1]], %[[device1]], %[[false]] : !torch.int, !torch.int, !torch.Device, !torch.bool -> !torch.vtensor<*,bf16>
-
-// CHECK: %[[cast:.*]] = torch.tensor_static_info_cast %[[result]] : !torch.vtensor<*,bf16> to !torch.vtensor
-// CHECK: return %[[cast]] : !torch.vtensor
-// CHECK: }
-
-func.func @prim.dtype(%arg: !torch.vtensor<*,bf16>) -> !torch.vtensor<*,unk> {
-  %zero = torch.constant.int 0
-  %false = torch.constant.bool false
-
-  // Op that requires type refinement
-  %neg = torch.aten.neg %arg : !torch.vtensor<*,bf16> -> !torch.vtensor<*,unk>
-
-  // Op whose processing requires type refinement on its source argument.
-  %dtype = torch.prim.dtype %neg : !torch.vtensor<*,unk> -> !torch.int
-  %device = torch.prim.device %neg : !torch.vtensor<*,unk> -> !torch.Device
-
-  // Another op that requires type refinement
-  %result = torch.aten.tensor.int %zero, %dtype, %device, %false : !torch.int, !torch.int, !torch.Device, !torch.bool -> !torch.vtensor<*,unk>
-
-  // Repeat the above three ops a second time to ensure that the type refinement
-  // code works regardless of the number of alternating refinement+prim.dtype
-  // sequences.
-  %dtype2 = torch.prim.dtype %result : !torch.vtensor<*,unk> -> !torch.int
-  %device2 = torch.prim.device %result : !torch.vtensor<*,unk> -> !torch.Device
-  %result2 = torch.aten.tensor.int %zero, %dtype2, %device2, %false : !torch.int, !torch.int, !torch.Device, !torch.bool -> !torch.vtensor<*,unk>
-
-  return %result2 : !torch.vtensor<*,unk>
-}
-
-// -----
-
-// Check that we don't crash on this input.
-
-// CHECK-LABEL: func.func @forward
-func.func @forward() -> !torch.vtensor {
-  %false = torch.constant.bool false
-  %none = torch.constant.none
-  %0 = torch.prim.ListConstruct : () -> !torch.list
-  // CHECK: torch.aten.tensor
-  %1 = torch.aten.tensor %0, %none, %none, %false : !torch.list, !torch.none, !torch.none, !torch.bool -> !torch.vtensor
-  return %1 : !torch.vtensor
-}
-
-// -----
-
-// Check that we don't crash on this input.
-// TODO: This appears to result in aten.mul.Tensor not being visited.
-// We should investigate why that happens.
-
-// CHECK-LABEL: func.func @forward
-func.func @forward(%arg0: !torch.bool, %arg1: !torch.tensor) {
-  %0 = torch.prim.If %arg0 -> (!torch.tensor) {
-    torch.prim.If.yield %arg1 : !torch.tensor
-  } else {
-    torch.prim.If.yield %arg1 : !torch.tensor
-  }
-  %1 = torch.copy.to_vtensor %0 : !torch.vtensor
-  %2 = torch.aten.mul.Tensor %1, %1 : !torch.vtensor, !torch.vtensor -> !torch.vtensor
-  return
-}
-
-// -----
-
-// CHECK-LABEL: func.func @torch.aten.zeros_like(
-// CHECK-SAME: %[[arg:.*]]: !torch.vtensor) {
-// CHECK: %[[INT6:.*]] = torch.constant.int 6
-// CHECK: %[[FALSE:.*]] = torch.constant.bool false
-// CHECK: %[[INT1:.*]] = torch.constant.int 1
-// CHECK: %[[INT0:.*]] = torch.constant.int 0
-// CHECK: %[[CPU:.*]] = torch.constant.device "cpu"
-// CHECK: %[[ZEROS:.*]] = torch.aten.zeros_like %[[arg]], %[[INT6]], %[[INT0]], %[[CPU]], %[[FALSE]], %[[INT1]] : !torch.vtensor, !torch.int, !torch.int, !torch.Device, !torch.bool, !torch.int -> !torch.vtensor<*,f32>
-// CHECK: return
-func.func @torch.aten.zeros_like(%arg: !torch.vtensor) {
-  %int6 = torch.constant.int 6
-  %false = torch.constant.bool false
-  %int1 = torch.constant.int 1
-  %int0 = torch.constant.int 0
-  %cpu = torch.constant.device "cpu"
-  %2 = torch.aten.zeros_like %arg, %int6, %int0, %cpu, %false, %int1 : !torch.vtensor, !torch.int, !torch.int, !torch.Device, !torch.bool, !torch.int -> !torch.vtensor
-  return
-}
-
-// -----
-
-// The data-flow analysis does not always propagate information to the entire graph.
-// This results in some lattice elements being uninitialized, which must be properly
-// handled when using the lattice elements to rewrite the graph.
-// In this particular case, the presence of the loop causes `torch.copy.to_vtensor`
-// to end up with an uninitialized lattice element. This is the simplest graph I was
-// able to come up with that reproduces such behavior.
-
-// CHECK-LABEL: func.func @uninitialized_lattice_elements(
-// CHECK: %{{.*}} = torch.copy.to_vtensor %{{.*}} : !torch.vtensor<*,f32>
-
-func.func @uninitialized_lattice_elements(%arg0: !torch.vtensor<*,f32>, %arg3: !torch.tensor) -> !torch.vtensor<*,f32> {
-  %true = torch.constant.bool true
-  %1 = torch.constant.int 0
-  %2 = torch.prim.Loop %1, %true, init(%arg3) {
-  ^bb0(%arg1: !torch.int, %arg2: !torch.tensor):
-    torch.prim.Loop.condition %true, iter(%arg2 : !torch.tensor)
-  } : (!torch.int, !torch.bool, !torch.tensor) -> !torch.tensor
-  %3 = torch.tensor_static_info_cast %2 : !torch.tensor to !torch.tensor<*,f32>
-  %4 = torch.copy.to_vtensor %3 : !torch.vtensor<*,f32>
-  return %4 : !torch.vtensor<*,f32>
-}