diff --git a/build_tools/python_deploy/build_linux_packages.sh b/build_tools/python_deploy/build_linux_packages.sh index d606749f8..cc8bf9c52 100755 --- a/build_tools/python_deploy/build_linux_packages.sh +++ b/build_tools/python_deploy/build_linux_packages.sh @@ -277,7 +277,7 @@ function test_in_tree() { python -m e2e_testing.main --config=lazy_tensor_core -v echo ":::: Run TorchDynamo e2e integration tests" - python -m e2e_testing.main --config=torchdynamo -v + python -m e2e_testing.main --config=torchdynamo -v --crashing_tests_to_not_attempt_to_run_and_a_bug_is_filed RandnDtypeDeviceModule_basic } function setup_venv() { diff --git a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp index 88b0555ea..c0a0df5ee 100644 --- a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp +++ b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp @@ -7350,6 +7350,112 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %4 = torch.prim.ListConstruct %0, %1, %2, %3 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list\n" " return %4 : !torch.list\n" " }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.tanh\"(%arg0: !torch.int, %arg1: !torch.int) -> !torch.int {\n" +" %0 = call @__torch__._get_dtype_of_floating_point_op(%arg1) : (!torch.int) -> !torch.int\n" +" return %0 : !torch.int\n" +" }\n" +" func.func @__torch__._get_dtype_of_floating_point_op(%arg0: !torch.int) -> !torch.int {\n" +" %int6 = torch.constant.int 6\n" +" %int5 = torch.constant.int 5\n" +" %int15 = torch.constant.int 15\n" +" %int7 = torch.constant.int 7\n" +" %0 = torch.prim.ListConstruct %int7, %int15, %int5 : (!torch.int, !torch.int, !torch.int) -> !torch.list\n" +" %1 = torch.aten.__contains__.int_list %0, %arg0 : !torch.list, !torch.int -> !torch.bool\n" +" %2 = torch.prim.If %1 -> (!torch.int) {\n" +" torch.prim.If.yield %arg0 : !torch.int\n" +" } else {\n" +" torch.prim.If.yield %int6 : !torch.int\n" +" }\n" +" return %2 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.exp\"(%arg0: !torch.int, %arg1: !torch.int) -> !torch.int {\n" +" %0 = call @__torch__._get_dtype_of_floating_point_op(%arg1) : (!torch.int) -> !torch.int\n" +" return %0 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.expm1\"(%arg0: !torch.int, %arg1: !torch.int) -> !torch.int {\n" +" %0 = call @__torch__._get_dtype_of_floating_point_op(%arg1) : (!torch.int) -> !torch.int\n" +" return %0 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.sin\"(%arg0: !torch.int, %arg1: !torch.int) -> !torch.int {\n" +" %0 = call @__torch__._get_dtype_of_floating_point_op(%arg1) : (!torch.int) -> !torch.int\n" +" return %0 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.cos\"(%arg0: !torch.int, %arg1: !torch.int) -> !torch.int {\n" +" %0 = call @__torch__._get_dtype_of_floating_point_op(%arg1) : (!torch.int) -> !torch.int\n" +" return %0 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.sigmoid\"(%arg0: !torch.int, %arg1: !torch.int) -> !torch.int {\n" +" %0 = call @__torch__._get_dtype_of_floating_point_op(%arg1) : (!torch.int) -> !torch.int\n" +" return %0 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.reciprocal\"(%arg0: !torch.int, %arg1: !torch.int) -> !torch.int {\n" +" %0 = call @__torch__._get_dtype_of_floating_point_op(%arg1) : (!torch.int) -> !torch.int\n" +" return %0 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.sqrt\"(%arg0: !torch.int, %arg1: !torch.int) -> !torch.int {\n" +" %0 = call @__torch__._get_dtype_of_floating_point_op(%arg1) : (!torch.int) -> !torch.int\n" +" return %0 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.log\"(%arg0: !torch.int, %arg1: !torch.int) -> !torch.int {\n" +" %0 = call @__torch__._get_dtype_of_floating_point_op(%arg1) : (!torch.int) -> !torch.int\n" +" return %0 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.log2\"(%arg0: !torch.int, %arg1: !torch.int) -> !torch.int {\n" +" %0 = call @__torch__._get_dtype_of_floating_point_op(%arg1) : (!torch.int) -> !torch.int\n" +" return %0 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.log1p\"(%arg0: !torch.int, %arg1: !torch.int) -> !torch.int {\n" +" %0 = call @__torch__._get_dtype_of_floating_point_op(%arg1) : (!torch.int) -> !torch.int\n" +" return %0 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.rsqrt\"(%arg0: !torch.int, %arg1: !torch.int) -> !torch.int {\n" +" %0 = call @__torch__._get_dtype_of_floating_point_op(%arg1) : (!torch.int) -> !torch.int\n" +" return %0 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.erf\"(%arg0: !torch.int, %arg1: !torch.int) -> !torch.int {\n" +" %0 = call @__torch__._get_dtype_of_floating_point_op(%arg1) : (!torch.int) -> !torch.int\n" +" return %0 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.softplus\"(%arg0: !torch.int, %arg1: !torch.int, %arg2: !torch.union, %arg3: !torch.union) -> !torch.int {\n" +" %none = torch.constant.none\n" +" %str = torch.constant.str \"AssertionError: \"\n" +" %int11 = torch.constant.int 11\n" +" %int3 = torch.constant.int 3\n" +" %int4 = torch.constant.int 4\n" +" %0 = torch.prim.ListConstruct %int4, %int3, %int11 : (!torch.int, !torch.int, !torch.int) -> !torch.list\n" +" %1 = torch.aten.__contains__.int_list %0, %arg1 : !torch.list, !torch.int -> !torch.bool\n" +" %2 = torch.aten.__not__ %1 : !torch.bool -> !torch.bool\n" +" torch.prim.If %2 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %3 = call @__torch__._get_dtype_of_floating_point_op(%arg1) : (!torch.int) -> !torch.int\n" +" return %3 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.frobenius_norm.dim\"(%arg0: !torch.int, %arg1: !torch.int, %arg2: !torch.list, %arg3: !torch.bool) -> !torch.int {\n" +" %none = torch.constant.none\n" +" %str = torch.constant.str \"AssertionError: \"\n" +" %int11 = torch.constant.int 11\n" +" %int3 = torch.constant.int 3\n" +" %int4 = torch.constant.int 4\n" +" %0 = torch.prim.ListConstruct %int4, %int3, %int11 : (!torch.int, !torch.int, !torch.int) -> !torch.list\n" +" %1 = torch.aten.__contains__.int_list %0, %arg1 : !torch.list, !torch.int -> !torch.bool\n" +" %2 = torch.aten.__not__ %1 : !torch.bool -> !torch.bool\n" +" torch.prim.If %2 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %3 = call @__torch__._get_dtype_of_floating_point_op(%arg1) : (!torch.int) -> !torch.int\n" +" return %3 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.prims.sqrt\"(%arg0: !torch.int, %arg1: !torch.int) -> !torch.int {\n" +" %0 = call @__torch__._get_dtype_of_floating_point_op(%arg1) : (!torch.int) -> !torch.int\n" +" return %0 : !torch.int\n" +" }\n" " func.func @\"__torch_mlir_dtype_fn.aten.all\"(%arg0: !torch.int, %arg1: !torch.int) -> !torch.int {\n" " %int11 = torch.constant.int 11\n" " return %int11 : !torch.int\n" @@ -7520,32 +7626,6 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %3 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%0, %2) : (!torch.list>, !torch.list) -> !torch.int\n" " return %3 : !torch.int\n" " }\n" -" func.func @\"__torch_mlir_dtype_fn.aten.expm1\"(%arg0: !torch.int, %arg1: !torch.int) -> !torch.int {\n" -" %int6 = torch.constant.int 6\n" -" %int5 = torch.constant.int 5\n" -" %int15 = torch.constant.int 15\n" -" %true = torch.constant.bool true\n" -" %int7 = torch.constant.int 7\n" -" %0 = torch.aten.eq.int %arg1, %int7 : !torch.int, !torch.int -> !torch.bool\n" -" %1 = torch.prim.If %0 -> (!torch.bool) {\n" -" torch.prim.If.yield %true : !torch.bool\n" -" } else {\n" -" %4 = torch.aten.eq.int %arg1, %int15 : !torch.int, !torch.int -> !torch.bool\n" -" torch.prim.If.yield %4 : !torch.bool\n" -" }\n" -" %2 = torch.prim.If %1 -> (!torch.bool) {\n" -" torch.prim.If.yield %true : !torch.bool\n" -" } else {\n" -" %4 = torch.aten.eq.int %arg1, %int5 : !torch.int, !torch.int -> !torch.bool\n" -" torch.prim.If.yield %4 : !torch.bool\n" -" }\n" -" %3 = torch.prim.If %2 -> (!torch.int) {\n" -" torch.prim.If.yield %arg1 : !torch.int\n" -" } else {\n" -" torch.prim.If.yield %int6 : !torch.int\n" -" }\n" -" return %3 : !torch.int\n" -" }\n" "}\n" ""; // clang-format on diff --git a/lib/Dialect/Torch/Transforms/RefineTypes.cpp b/lib/Dialect/Torch/Transforms/RefineTypes.cpp index 414e37660..94436925e 100644 --- a/lib/Dialect/Torch/Transforms/RefineTypes.cpp +++ b/lib/Dialect/Torch/Transforms/RefineTypes.cpp @@ -672,23 +672,6 @@ void TypeAnalysis::visitOperation(Operation *op, return incorporateKnowledge(op->getResult(0), operands[0]->getValue()); } - // Dtype is always float32, except for bfloat16, float16, float64 and nullptr. - if (isa(op)) { - ValueKnowledge knowledge = - ValueKnowledge::getTensorPessimisticValueState(op->getContext()); - Type dtype = operands[0]->getValue().dtype; - if (dtype) { - knowledge.dtype = Float32Type::get(op->getContext()); - if (dtype.isa()) - knowledge.dtype = dtype; - } - incorporateKnowledge(op->getResult(0), knowledge); - return; - } - // Take dtype from second operand. if (isa(op)) { auto self = operands[1]->getValue(); diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/abstract_interp_lib_gen.py b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/abstract_interp_lib_gen.py index 5502fe551..d95aba4b2 100644 --- a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/abstract_interp_lib_gen.py +++ b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/abstract_interp_lib_gen.py @@ -1026,19 +1026,107 @@ def _get_invocations_for_op_with_tensor_arg_followed_by(*args): dtype function instead of using this helper function. """ return [ - Invocation(NonZeroDTensorWithDtype(torch.float32), *args), - Invocation(NonZeroDTensorWithDtype(torch.float64), *args), - Invocation(NonZeroDTensorWithDtype(torch.bfloat16), *args), - Invocation(NonZeroDTensorWithDtype(torch.int64), *args), - Invocation(NonZeroDTensorWithDtype(torch.int32), *args), - Invocation(NonZeroDTensorWithDtype(torch.bool), *args), - Invocation(ZeroDTensorWithDtype(torch.float32), *args), - Invocation(ZeroDTensorWithDtype(torch.float64), *args), - Invocation(ZeroDTensorWithDtype(torch.bfloat16), *args), - Invocation(ZeroDTensorWithDtype(torch.int64), *args), - Invocation(ZeroDTensorWithDtype(torch.int32), *args), - Invocation(ZeroDTensorWithDtype(torch.bool), *args), -] + Invocation(NonZeroDTensorWithDtype(torch.float32), *args), + Invocation(NonZeroDTensorWithDtype(torch.float64), *args), + Invocation(NonZeroDTensorWithDtype(torch.bfloat16), *args), + Invocation(NonZeroDTensorWithDtype(torch.int64), *args), + Invocation(NonZeroDTensorWithDtype(torch.int32), *args), + Invocation(NonZeroDTensorWithDtype(torch.bool), *args), + Invocation(ZeroDTensorWithDtype(torch.float32), *args), + Invocation(ZeroDTensorWithDtype(torch.float64), *args), + Invocation(ZeroDTensorWithDtype(torch.bfloat16), *args), + Invocation(ZeroDTensorWithDtype(torch.int64), *args), + Invocation(ZeroDTensorWithDtype(torch.int32), *args), + Invocation(ZeroDTensorWithDtype(torch.bool), *args), + ] + +def _get_invocations_for_fp_only_op_with_tensor_arg_followed_by(*args): + """Generate invocations for floating point only op.""" + return [ + Invocation(NonZeroDTensorWithDtype(torch.float32), *args), + Invocation(NonZeroDTensorWithDtype(torch.float64), *args), + Invocation(NonZeroDTensorWithDtype(torch.bfloat16), *args), + ErrorInvocation(NonZeroDTensorWithDtype(torch.int64), *args), + ErrorInvocation(NonZeroDTensorWithDtype(torch.int32), *args), + ErrorInvocation(NonZeroDTensorWithDtype(torch.bool), *args), + Invocation(ZeroDTensorWithDtype(torch.float32), *args), + Invocation(ZeroDTensorWithDtype(torch.float64), *args), + Invocation(ZeroDTensorWithDtype(torch.bfloat16), *args), + ErrorInvocation(ZeroDTensorWithDtype(torch.int64), *args), + ErrorInvocation(ZeroDTensorWithDtype(torch.int32), *args), + ErrorInvocation(ZeroDTensorWithDtype(torch.bool), *args), + ] + +def _get_dtype_of_floating_point_op(input_dtype: int) -> int: + if input_dtype in (torch.float64, torch.bfloat16, torch.float16): + return input_dtype + return torch.float32 + +@check_dtype_function(_get_invocations_for_op_with_tensor_arg_followed_by()) +def aten〇tanh〡dtype(self_rank: int, self_dtype: int) -> int: + return _get_dtype_of_floating_point_op(self_dtype) + +@check_dtype_function(_get_invocations_for_op_with_tensor_arg_followed_by()) +def aten〇exp〡dtype(self_rank: int, self_dtype: int) -> int: + return _get_dtype_of_floating_point_op(self_dtype) + +@check_dtype_function(_get_invocations_for_op_with_tensor_arg_followed_by()) +def aten〇expm1〡dtype(self_rank: int, self_dtype: int) -> int: + return _get_dtype_of_floating_point_op(self_dtype) + +@check_dtype_function(_get_invocations_for_op_with_tensor_arg_followed_by()) +def aten〇sin〡dtype(self_rank: int, self_dtype: int) -> int: + return _get_dtype_of_floating_point_op(self_dtype) + +@check_dtype_function(_get_invocations_for_op_with_tensor_arg_followed_by()) +def aten〇cos〡dtype(self_rank: int, self_dtype: int) -> int: + return _get_dtype_of_floating_point_op(self_dtype) + +@check_dtype_function(_get_invocations_for_op_with_tensor_arg_followed_by()) +def aten〇sigmoid〡dtype(self_rank: int, self_dtype: int) -> int: + return _get_dtype_of_floating_point_op(self_dtype) + +@check_dtype_function(_get_invocations_for_op_with_tensor_arg_followed_by()) +def aten〇reciprocal〡dtype(self_rank: int, self_dtype: int) -> int: + return _get_dtype_of_floating_point_op(self_dtype) + +@check_dtype_function(_get_invocations_for_op_with_tensor_arg_followed_by()) +def aten〇sqrt〡dtype(self_rank: int, self_dtype: int) -> int: + return _get_dtype_of_floating_point_op(self_dtype) + +@check_dtype_function(_get_invocations_for_op_with_tensor_arg_followed_by()) +def aten〇log〡dtype(self_rank: int, self_dtype: int) -> int: + return _get_dtype_of_floating_point_op(self_dtype) + +@check_dtype_function(_get_invocations_for_op_with_tensor_arg_followed_by()) +def aten〇log2〡dtype(self_rank: int, self_dtype: int) -> int: + return _get_dtype_of_floating_point_op(self_dtype) + +@check_dtype_function(_get_invocations_for_op_with_tensor_arg_followed_by()) +def aten〇log1p〡dtype(self_rank: int, self_dtype: int) -> int: + return _get_dtype_of_floating_point_op(self_dtype) + +@check_dtype_function(_get_invocations_for_op_with_tensor_arg_followed_by()) +def aten〇rsqrt〡dtype(self_rank: int, self_dtype: int) -> int: + return _get_dtype_of_floating_point_op(self_dtype) + +@check_dtype_function(_get_invocations_for_op_with_tensor_arg_followed_by()) +def aten〇erf〡dtype(self_rank: int, self_dtype: int) -> int: + return _get_dtype_of_floating_point_op(self_dtype) + +@check_dtype_function(_get_invocations_for_fp_only_op_with_tensor_arg_followed_by()) +def aten〇softplus〡dtype(self_rank: int, self_dtype: int, beta: Union[int, float] = 1, threshold: Union[int, float] = 20) -> int: + assert self_dtype not in (torch.int64, torch.int32, torch.bool) + return _get_dtype_of_floating_point_op(self_dtype) + +@check_dtype_function(_get_invocations_for_fp_only_op_with_tensor_arg_followed_by([0])) +def aten〇frobenius_norm〇dim〡dtype(self_rank: int, self_dtype: int, dim: List[int], keepdim: bool = False) -> int: + assert self_dtype not in (torch.int64, torch.int32, torch.bool) + return _get_dtype_of_floating_point_op(self_dtype) + +@check_dtype_function(_get_invocations_for_op_with_tensor_arg_followed_by()) +def prims〇sqrt〡dtype(self_rank: int, self_dtype: int) -> int: + return _get_dtype_of_floating_point_op(self_dtype) @check_dtype_function(_get_invocations_for_op_with_tensor_arg_followed_by()) def aten〇all〡dtype(self_rank: int, self_dtype: int) -> int: @@ -1167,13 +1255,6 @@ def aten〇floor_divide〡dtype(self_rank: int, self_dtype: int, other_rank: int def aten〇rsub〇Scalar〡dtype(self_rank: int, self_dtype: int, other: Union[int, float], alpha: Union[int, float] = 1) -> int: return promote_dtypes([self_rank, None], [self_dtype, get_dtype_of_scalar(other)]) -@check_dtype_function(_get_invocations_for_op_with_tensor_arg_followed_by()) -def aten〇expm1〡dtype(self_rank: int, self_dtype: int) -> int: - if self_dtype == torch.float64 or self_dtype == torch.bfloat16 or self_dtype == torch.float16: - return self_dtype - else: - return torch.float32 - # ============================================================================== # Main # ============================================================================== diff --git a/test/Dialect/Torch/refine-types-branch.mlir b/test/Dialect/Torch/refine-types-branch.mlir index 87ff96576..3c76ac95f 100644 --- a/test/Dialect/Torch/refine-types-branch.mlir +++ b/test/Dialect/Torch/refine-types-branch.mlir @@ -117,37 +117,3 @@ func.func @prim.loop$region_arg_to_internal(%none: !torch.none) -> !torch.option } : (!torch.int, !torch.bool, !torch.optional) -> (!torch.optional) return %ret: !torch.optional } - -// ----- - -// CHECK-LABEL: func.func @f -// CHECK: %[[ATEN:.*]] = torch.aten.cos %{{.*}} : !torch.vtensor -> !torch.vtensor<*,f32> -// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[ATEN]] : !torch.vtensor<*,f32> to !torch.vtensor -// CHECK: return %[[CAST]] : !torch.vtensor -func.func @f(%arg0: !torch.vtensor<*,f32>) -> !torch.vtensor { - %cast = torch.tensor_static_info_cast %arg0 : !torch.vtensor<*,f32> to !torch.vtensor - cf.br ^bb1(%cast: !torch.vtensor) -^bb1(%arg1: !torch.vtensor): - %1 = torch.aten.cos %arg1 : !torch.vtensor -> !torch.vtensor - return %1 : !torch.vtensor -} - -// ----- - -// CHECK-LABEL: func.func @f -// CHECK: func.func private @callee -// CHECK-NEXT: torch.aten.cos %{{.*}} : !torch.vtensor -> !torch.vtensor<*,f32> -func.func @f() { - builtin.module { - func.func private @callee(%arg0: !torch.vtensor) { - %1 = torch.aten.cos %arg0 : !torch.vtensor -> !torch.vtensor - return - } - func.func @caller(%arg0: !torch.vtensor<*,f32>) { - %cast = torch.tensor_static_info_cast %arg0 : !torch.vtensor<*,f32> to !torch.vtensor - call @callee(%cast) : (!torch.vtensor) -> () - return - } - } - return -} diff --git a/test/Dialect/Torch/refine-types.mlir b/test/Dialect/Torch/refine-types.mlir index 8e7a689f5..391fd6d2a 100644 --- a/test/Dialect/Torch/refine-types.mlir +++ b/test/Dialect/Torch/refine-types.mlir @@ -6,57 +6,6 @@ // Code for testing transfer functions for new ops (which is most changes) // should go in refine-types-ops.mlir. -// ----- -// CHECK-LABEL: func.func @basic( -// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<*,f32>) -> !torch.vtensor { -// CHECK: %[[COS:.*]] = torch.aten.cos %[[ARG0]] : !torch.vtensor<*,f32> -> !torch.vtensor<*,f32> -// CHECK: %[[RESULT:.*]] = torch.tensor_static_info_cast %[[COS]] : !torch.vtensor<*,f32> to !torch.vtensor -// CHECK: return %[[RESULT]] : !torch.vtensor -func.func @basic(%arg0: !torch.vtensor<*,f32>) -> !torch.vtensor { - %1 = torch.aten.cos %arg0 : !torch.vtensor<*,f32> -> !torch.vtensor - return %1 : !torch.vtensor -} - -// ----- -// CHECK-LABEL: func.func @keep_existing_shape_information( -// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<*,f32>) -> !torch.vtensor<[2],f32> { -// CHECK: %[[COS:.*]] = torch.aten.cos %[[ARG0]] : !torch.vtensor<*,f32> -> !torch.vtensor<[2],f32> -// CHECK: return %[[COS]] : !torch.vtensor<[2],f32> -func.func @keep_existing_shape_information(%arg0: !torch.vtensor<*,f32>) -> !torch.vtensor<[2],f32> { - %1 = torch.aten.cos %arg0 : !torch.vtensor<*,f32> -> !torch.vtensor<[2], f32> - return %1 : !torch.vtensor<[2],f32> -} - -// ----- -// CHECK-LABEL: func.func @propagate_through_multiple_ops( -// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<*,f32>) -> !torch.vtensor { -// CHECK: %[[COS0:.*]] = torch.aten.cos %[[ARG0]] : !torch.vtensor<*,f32> -> !torch.vtensor<*,f32> -// CHECK: %[[COS1:.*]] = torch.aten.cos %[[COS0]] : !torch.vtensor<*,f32> -> !torch.vtensor<*,f32> -// CHECK: %[[COS2:.*]] = torch.aten.cos %[[COS1]] : !torch.vtensor<*,f32> -> !torch.vtensor<*,f32> -// CHECK: %[[COS3:.*]] = torch.tensor_static_info_cast %[[COS2]] : !torch.vtensor<*,f32> to !torch.vtensor -// CHECK: return %[[COS3]] : !torch.vtensor -func.func @propagate_through_multiple_ops(%arg0: !torch.vtensor<*,f32>) -> !torch.vtensor { - %1 = torch.aten.cos %arg0 : !torch.vtensor<*,f32> -> !torch.vtensor - %2 = torch.aten.cos %1 : !torch.vtensor -> !torch.vtensor - %3 = torch.aten.cos %2 : !torch.vtensor -> !torch.vtensor - return %3 : !torch.vtensor -} - -// ----- -// Check rewriting logic in case of mixes of users that do/don't allow type -// refinement. -// CHECK-LABEL: func.func @mixed_allowing_not_allowing_type_refinement( -// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<*,f32>) -> (!torch.vtensor, !torch.vtensor) { -// CHECK: %[[COS0:.*]] = torch.aten.cos %[[ARG0]] : !torch.vtensor<*,f32> -> !torch.vtensor<*,f32> -// CHECK: %[[ERASED:.*]] = torch.tensor_static_info_cast %[[COS0]] : !torch.vtensor<*,f32> to !torch.vtensor -// CHECK: %[[COS1:.*]] = torch.aten.cos %[[COS0]] : !torch.vtensor<*,f32> -> !torch.vtensor<*,f32> -// CHECK: return %[[ERASED]], %[[ERASED]] : !torch.vtensor, !torch.vtensor -func.func @mixed_allowing_not_allowing_type_refinement(%arg0: !torch.vtensor<*,f32>) -> (!torch.vtensor, !torch.vtensor) { - %1 = torch.aten.cos %arg0 : !torch.vtensor<*,f32> -> !torch.vtensor - %3 = torch.aten.cos %1 : !torch.vtensor -> !torch.vtensor - return %1, %1 : !torch.vtensor, !torch.vtensor -} - // ----- // CHECK-LABEL: func.func @torch.overwrite.tensor.contents$dynamic_overwrites_static(