Add dtype functions for floating point ops (#1813)

Branch: pull/1831/head
Author: Jiahao Li, 2023-01-21 02:39:41 +08:00 (committed by GitHub)
Parent: 8cae5ba507
Commit: 83d4e89d25
6 changed files with 208 additions and 149 deletions
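All of the dtype functions added in this change share one promotion rule: integer and bool inputs produce float32 results, while float64, bfloat16, and float16 inputs keep their dtype. The snippet below is an illustrative sketch, not part of the patch itself; it restates the _get_dtype_of_floating_point_op helper added further down and cross-checks it against eager PyTorch (assuming a recent CPU build of torch).

import torch

# Sketch of the shared rule, written against torch.dtype for readability.
def dtype_of_floating_point_op(input_dtype: torch.dtype) -> torch.dtype:
    if input_dtype in (torch.float64, torch.bfloat16, torch.float16):
        return input_dtype
    return torch.float32

# Eager PyTorch follows the same rule for a representative unary float op.
assert torch.tanh(torch.ones(3, dtype=torch.int64)).dtype == torch.float32
assert torch.tanh(torch.ones(3, dtype=torch.float64)).dtype == torch.float64
assert torch.tanh(torch.ones(3, dtype=torch.bfloat16)).dtype == torch.bfloat16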

@@ -277,7 +277,7 @@ function test_in_tree() {
python -m e2e_testing.main --config=lazy_tensor_core -v
echo ":::: Run TorchDynamo e2e integration tests"
python -m e2e_testing.main --config=torchdynamo -v
python -m e2e_testing.main --config=torchdynamo -v --crashing_tests_to_not_attempt_to_run_and_a_bug_is_filed RandnDtypeDeviceModule_basic
}
function setup_venv() {

@@ -7350,6 +7350,112 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %4 = torch.prim.ListConstruct %0, %1, %2, %3 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>\n"
" return %4 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.tanh\"(%arg0: !torch.int, %arg1: !torch.int) -> !torch.int {\n"
" %0 = call @__torch__._get_dtype_of_floating_point_op(%arg1) : (!torch.int) -> !torch.int\n"
" return %0 : !torch.int\n"
" }\n"
" func.func @__torch__._get_dtype_of_floating_point_op(%arg0: !torch.int) -> !torch.int {\n"
" %int6 = torch.constant.int 6\n"
" %int5 = torch.constant.int 5\n"
" %int15 = torch.constant.int 15\n"
" %int7 = torch.constant.int 7\n"
" %0 = torch.prim.ListConstruct %int7, %int15, %int5 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>\n"
" %1 = torch.aten.__contains__.int_list %0, %arg0 : !torch.list<int>, !torch.int -> !torch.bool\n"
" %2 = torch.prim.If %1 -> (!torch.int) {\n"
" torch.prim.If.yield %arg0 : !torch.int\n"
" } else {\n"
" torch.prim.If.yield %int6 : !torch.int\n"
" }\n"
" return %2 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.exp\"(%arg0: !torch.int, %arg1: !torch.int) -> !torch.int {\n"
" %0 = call @__torch__._get_dtype_of_floating_point_op(%arg1) : (!torch.int) -> !torch.int\n"
" return %0 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.expm1\"(%arg0: !torch.int, %arg1: !torch.int) -> !torch.int {\n"
" %0 = call @__torch__._get_dtype_of_floating_point_op(%arg1) : (!torch.int) -> !torch.int\n"
" return %0 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.sin\"(%arg0: !torch.int, %arg1: !torch.int) -> !torch.int {\n"
" %0 = call @__torch__._get_dtype_of_floating_point_op(%arg1) : (!torch.int) -> !torch.int\n"
" return %0 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.cos\"(%arg0: !torch.int, %arg1: !torch.int) -> !torch.int {\n"
" %0 = call @__torch__._get_dtype_of_floating_point_op(%arg1) : (!torch.int) -> !torch.int\n"
" return %0 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.sigmoid\"(%arg0: !torch.int, %arg1: !torch.int) -> !torch.int {\n"
" %0 = call @__torch__._get_dtype_of_floating_point_op(%arg1) : (!torch.int) -> !torch.int\n"
" return %0 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.reciprocal\"(%arg0: !torch.int, %arg1: !torch.int) -> !torch.int {\n"
" %0 = call @__torch__._get_dtype_of_floating_point_op(%arg1) : (!torch.int) -> !torch.int\n"
" return %0 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.sqrt\"(%arg0: !torch.int, %arg1: !torch.int) -> !torch.int {\n"
" %0 = call @__torch__._get_dtype_of_floating_point_op(%arg1) : (!torch.int) -> !torch.int\n"
" return %0 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.log\"(%arg0: !torch.int, %arg1: !torch.int) -> !torch.int {\n"
" %0 = call @__torch__._get_dtype_of_floating_point_op(%arg1) : (!torch.int) -> !torch.int\n"
" return %0 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.log2\"(%arg0: !torch.int, %arg1: !torch.int) -> !torch.int {\n"
" %0 = call @__torch__._get_dtype_of_floating_point_op(%arg1) : (!torch.int) -> !torch.int\n"
" return %0 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.log1p\"(%arg0: !torch.int, %arg1: !torch.int) -> !torch.int {\n"
" %0 = call @__torch__._get_dtype_of_floating_point_op(%arg1) : (!torch.int) -> !torch.int\n"
" return %0 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.rsqrt\"(%arg0: !torch.int, %arg1: !torch.int) -> !torch.int {\n"
" %0 = call @__torch__._get_dtype_of_floating_point_op(%arg1) : (!torch.int) -> !torch.int\n"
" return %0 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.erf\"(%arg0: !torch.int, %arg1: !torch.int) -> !torch.int {\n"
" %0 = call @__torch__._get_dtype_of_floating_point_op(%arg1) : (!torch.int) -> !torch.int\n"
" return %0 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.softplus\"(%arg0: !torch.int, %arg1: !torch.int, %arg2: !torch.union<float, int>, %arg3: !torch.union<float, int>) -> !torch.int {\n"
" %none = torch.constant.none\n"
" %str = torch.constant.str \"AssertionError: \"\n"
" %int11 = torch.constant.int 11\n"
" %int3 = torch.constant.int 3\n"
" %int4 = torch.constant.int 4\n"
" %0 = torch.prim.ListConstruct %int4, %int3, %int11 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>\n"
" %1 = torch.aten.__contains__.int_list %0, %arg1 : !torch.list<int>, !torch.int -> !torch.bool\n"
" %2 = torch.aten.__not__ %1 : !torch.bool -> !torch.bool\n"
" torch.prim.If %2 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" %3 = call @__torch__._get_dtype_of_floating_point_op(%arg1) : (!torch.int) -> !torch.int\n"
" return %3 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.frobenius_norm.dim\"(%arg0: !torch.int, %arg1: !torch.int, %arg2: !torch.list<int>, %arg3: !torch.bool) -> !torch.int {\n"
" %none = torch.constant.none\n"
" %str = torch.constant.str \"AssertionError: \"\n"
" %int11 = torch.constant.int 11\n"
" %int3 = torch.constant.int 3\n"
" %int4 = torch.constant.int 4\n"
" %0 = torch.prim.ListConstruct %int4, %int3, %int11 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>\n"
" %1 = torch.aten.__contains__.int_list %0, %arg1 : !torch.list<int>, !torch.int -> !torch.bool\n"
" %2 = torch.aten.__not__ %1 : !torch.bool -> !torch.bool\n"
" torch.prim.If %2 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" %3 = call @__torch__._get_dtype_of_floating_point_op(%arg1) : (!torch.int) -> !torch.int\n"
" return %3 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.prims.sqrt\"(%arg0: !torch.int, %arg1: !torch.int) -> !torch.int {\n"
" %0 = call @__torch__._get_dtype_of_floating_point_op(%arg1) : (!torch.int) -> !torch.int\n"
" return %0 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.all\"(%arg0: !torch.int, %arg1: !torch.int) -> !torch.int {\n"
" %int11 = torch.constant.int 11\n"
" return %int11 : !torch.int\n"
@@ -7520,32 +7626,6 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %3 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%0, %2) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n"
" return %3 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.expm1\"(%arg0: !torch.int, %arg1: !torch.int) -> !torch.int {\n"
" %int6 = torch.constant.int 6\n"
" %int5 = torch.constant.int 5\n"
" %int15 = torch.constant.int 15\n"
" %true = torch.constant.bool true\n"
" %int7 = torch.constant.int 7\n"
" %0 = torch.aten.eq.int %arg1, %int7 : !torch.int, !torch.int -> !torch.bool\n"
" %1 = torch.prim.If %0 -> (!torch.bool) {\n"
" torch.prim.If.yield %true : !torch.bool\n"
" } else {\n"
" %4 = torch.aten.eq.int %arg1, %int15 : !torch.int, !torch.int -> !torch.bool\n"
" torch.prim.If.yield %4 : !torch.bool\n"
" }\n"
" %2 = torch.prim.If %1 -> (!torch.bool) {\n"
" torch.prim.If.yield %true : !torch.bool\n"
" } else {\n"
" %4 = torch.aten.eq.int %arg1, %int5 : !torch.int, !torch.int -> !torch.bool\n"
" torch.prim.If.yield %4 : !torch.bool\n"
" }\n"
" %3 = torch.prim.If %2 -> (!torch.int) {\n"
" torch.prim.If.yield %arg1 : !torch.int\n"
" } else {\n"
" torch.prim.If.yield %int6 : !torch.int\n"
" }\n"
" return %3 : !torch.int\n"
" }\n"
"}\n"
"";
// clang-format on
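A reading aid for the generated IR above (a note added here for clarity, not part of the patch): the bare integer constants are PyTorch ScalarType codes. %int6 is the float32 default, %int5, %int7, and %int15 are the preserved float16, float64, and bfloat16 cases, and %int3, %int4, and %int11 are the int32, int64, and bool dtypes that softplus and frobenius_norm.dim reject.

import torch

# Assumed mapping from the ScalarType codes used in the IR above to torch dtypes.
SCALAR_TYPE_CODES = {
    3: torch.int32,
    4: torch.int64,
    5: torch.float16,
    6: torch.float32,
    7: torch.float64,
    11: torch.bool,
    15: torch.bfloat16,
}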

@@ -672,23 +672,6 @@ void TypeAnalysis::visitOperation(Operation *op,
return incorporateKnowledge(op->getResult(0), operands[0]->getValue());
}
// Dtype is always float32, except for bfloat16, float16, float64 and nullptr.
if (isa<AtenTanhOp, AtenExpOp, AtenSinOp, AtenCosOp, AtenSigmoidOp,
AtenReciprocalOp, AtenLogOp, AtenSqrtOp, AtenLog2Op, AtenLog1pOp,
AtenRsqrtOp, AtenErfOp, AtenSoftplusOp, AtenFrobeniusNormDimOp,
PrimsSqrtOp>(op)) {
ValueKnowledge knowledge =
ValueKnowledge::getTensorPessimisticValueState(op->getContext());
Type dtype = operands[0]->getValue().dtype;
if (dtype) {
knowledge.dtype = Float32Type::get(op->getContext());
if (dtype.isa<BFloat16Type, Float16Type, Float64Type>())
knowledge.dtype = dtype;
}
incorporateKnowledge(op->getResult(0), knowledge);
return;
}
// Take dtype from second operand.
if (isa<AtenNllLossBackwardOp, AtenMaxPool2dWithIndicesBackwardOp>(op)) {
auto self = operands[1]->getValue();

@@ -1026,19 +1026,107 @@ def _get_invocations_for_op_with_tensor_arg_followed_by(*args):
dtype function instead of using this helper function.
"""
return [
Invocation(NonZeroDTensorWithDtype(torch.float32), *args),
Invocation(NonZeroDTensorWithDtype(torch.float64), *args),
Invocation(NonZeroDTensorWithDtype(torch.bfloat16), *args),
Invocation(NonZeroDTensorWithDtype(torch.int64), *args),
Invocation(NonZeroDTensorWithDtype(torch.int32), *args),
Invocation(NonZeroDTensorWithDtype(torch.bool), *args),
Invocation(ZeroDTensorWithDtype(torch.float32), *args),
Invocation(ZeroDTensorWithDtype(torch.float64), *args),
Invocation(ZeroDTensorWithDtype(torch.bfloat16), *args),
Invocation(ZeroDTensorWithDtype(torch.int64), *args),
Invocation(ZeroDTensorWithDtype(torch.int32), *args),
Invocation(ZeroDTensorWithDtype(torch.bool), *args),
]
Invocation(NonZeroDTensorWithDtype(torch.float32), *args),
Invocation(NonZeroDTensorWithDtype(torch.float64), *args),
Invocation(NonZeroDTensorWithDtype(torch.bfloat16), *args),
Invocation(NonZeroDTensorWithDtype(torch.int64), *args),
Invocation(NonZeroDTensorWithDtype(torch.int32), *args),
Invocation(NonZeroDTensorWithDtype(torch.bool), *args),
Invocation(ZeroDTensorWithDtype(torch.float32), *args),
Invocation(ZeroDTensorWithDtype(torch.float64), *args),
Invocation(ZeroDTensorWithDtype(torch.bfloat16), *args),
Invocation(ZeroDTensorWithDtype(torch.int64), *args),
Invocation(ZeroDTensorWithDtype(torch.int32), *args),
Invocation(ZeroDTensorWithDtype(torch.bool), *args),
]

def _get_invocations_for_fp_only_op_with_tensor_arg_followed_by(*args):
    """Generate invocations for floating point only op."""
    return [
        Invocation(NonZeroDTensorWithDtype(torch.float32), *args),
        Invocation(NonZeroDTensorWithDtype(torch.float64), *args),
        Invocation(NonZeroDTensorWithDtype(torch.bfloat16), *args),
        ErrorInvocation(NonZeroDTensorWithDtype(torch.int64), *args),
        ErrorInvocation(NonZeroDTensorWithDtype(torch.int32), *args),
        ErrorInvocation(NonZeroDTensorWithDtype(torch.bool), *args),
        Invocation(ZeroDTensorWithDtype(torch.float32), *args),
        Invocation(ZeroDTensorWithDtype(torch.float64), *args),
        Invocation(ZeroDTensorWithDtype(torch.bfloat16), *args),
        ErrorInvocation(ZeroDTensorWithDtype(torch.int64), *args),
        ErrorInvocation(ZeroDTensorWithDtype(torch.int32), *args),
        ErrorInvocation(ZeroDTensorWithDtype(torch.bool), *args),
    ]

def _get_dtype_of_floating_point_op(input_dtype: int) -> int:
    if input_dtype in (torch.float64, torch.bfloat16, torch.float16):
        return input_dtype
    return torch.float32

@check_dtype_function(_get_invocations_for_op_with_tensor_arg_followed_by())
def aten〇tanh〡dtype(self_rank: int, self_dtype: int) -> int:
    return _get_dtype_of_floating_point_op(self_dtype)

@check_dtype_function(_get_invocations_for_op_with_tensor_arg_followed_by())
def aten〇exp〡dtype(self_rank: int, self_dtype: int) -> int:
    return _get_dtype_of_floating_point_op(self_dtype)

@check_dtype_function(_get_invocations_for_op_with_tensor_arg_followed_by())
def aten〇expm1〡dtype(self_rank: int, self_dtype: int) -> int:
    return _get_dtype_of_floating_point_op(self_dtype)

@check_dtype_function(_get_invocations_for_op_with_tensor_arg_followed_by())
def aten〇sin〡dtype(self_rank: int, self_dtype: int) -> int:
    return _get_dtype_of_floating_point_op(self_dtype)

@check_dtype_function(_get_invocations_for_op_with_tensor_arg_followed_by())
def aten〇cos〡dtype(self_rank: int, self_dtype: int) -> int:
    return _get_dtype_of_floating_point_op(self_dtype)

@check_dtype_function(_get_invocations_for_op_with_tensor_arg_followed_by())
def aten〇sigmoid〡dtype(self_rank: int, self_dtype: int) -> int:
    return _get_dtype_of_floating_point_op(self_dtype)

@check_dtype_function(_get_invocations_for_op_with_tensor_arg_followed_by())
def aten〇reciprocal〡dtype(self_rank: int, self_dtype: int) -> int:
    return _get_dtype_of_floating_point_op(self_dtype)

@check_dtype_function(_get_invocations_for_op_with_tensor_arg_followed_by())
def aten〇sqrt〡dtype(self_rank: int, self_dtype: int) -> int:
    return _get_dtype_of_floating_point_op(self_dtype)

@check_dtype_function(_get_invocations_for_op_with_tensor_arg_followed_by())
def aten〇log〡dtype(self_rank: int, self_dtype: int) -> int:
    return _get_dtype_of_floating_point_op(self_dtype)

@check_dtype_function(_get_invocations_for_op_with_tensor_arg_followed_by())
def aten〇log2〡dtype(self_rank: int, self_dtype: int) -> int:
    return _get_dtype_of_floating_point_op(self_dtype)

@check_dtype_function(_get_invocations_for_op_with_tensor_arg_followed_by())
def aten〇log1p〡dtype(self_rank: int, self_dtype: int) -> int:
    return _get_dtype_of_floating_point_op(self_dtype)

@check_dtype_function(_get_invocations_for_op_with_tensor_arg_followed_by())
def aten〇rsqrt〡dtype(self_rank: int, self_dtype: int) -> int:
    return _get_dtype_of_floating_point_op(self_dtype)

@check_dtype_function(_get_invocations_for_op_with_tensor_arg_followed_by())
def aten〇erf〡dtype(self_rank: int, self_dtype: int) -> int:
    return _get_dtype_of_floating_point_op(self_dtype)

@check_dtype_function(_get_invocations_for_fp_only_op_with_tensor_arg_followed_by())
def aten〇softplus〡dtype(self_rank: int, self_dtype: int, beta: Union[int, float] = 1, threshold: Union[int, float] = 20) -> int:
    assert self_dtype not in (torch.int64, torch.int32, torch.bool)
    return _get_dtype_of_floating_point_op(self_dtype)

@check_dtype_function(_get_invocations_for_fp_only_op_with_tensor_arg_followed_by([0]))
def aten〇frobenius_norm〇dim〡dtype(self_rank: int, self_dtype: int, dim: List[int], keepdim: bool = False) -> int:
    assert self_dtype not in (torch.int64, torch.int32, torch.bool)
    return _get_dtype_of_floating_point_op(self_dtype)

@check_dtype_function(_get_invocations_for_op_with_tensor_arg_followed_by())
def prims〇sqrt〡dtype(self_rank: int, self_dtype: int) -> int:
    return _get_dtype_of_floating_point_op(self_dtype)

@check_dtype_function(_get_invocations_for_op_with_tensor_arg_followed_by())
def aten〇all〡dtype(self_rank: int, self_dtype: int) -> int:
@@ -1167,13 +1255,6 @@ def aten〇floor_divide〡dtype(self_rank: int, self_dtype: int, other_rank: int
def aten〇rsub〇Scalar〡dtype(self_rank: int, self_dtype: int, other: Union[int, float], alpha: Union[int, float] = 1) -> int:
    return promote_dtypes([self_rank, None], [self_dtype, get_dtype_of_scalar(other)])

@check_dtype_function(_get_invocations_for_op_with_tensor_arg_followed_by())
def aten〇expm1〡dtype(self_rank: int, self_dtype: int) -> int:
    if self_dtype == torch.float64 or self_dtype == torch.bfloat16 or self_dtype == torch.float16:
        return self_dtype
    else:
        return torch.float32

# ==============================================================================
# Main
# ==============================================================================
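Why the new fp-only invocation helper above expects errors for integer dtypes: eager PyTorch rejects integer inputs for ops such as softplus, whereas ops like tanh silently compute in float32; that is why softplus and frobenius_norm.dim use ErrorInvocation while the other ops use plain Invocation. A quick illustrative check (not part of the patch, assuming a recent PyTorch):

import torch
import torch.nn.functional as F

print(torch.tanh(torch.ones(2, dtype=torch.int64)).dtype)  # torch.float32
try:
    F.softplus(torch.ones(2, dtype=torch.int64))
except RuntimeError as err:
    # Integer inputs are rejected, matching the ErrorInvocation expectations.
    print("softplus rejected int64:", err)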

@@ -117,37 +117,3 @@ func.func @prim.loop$region_arg_to_internal(%none: !torch.none) -> !torch.option
} : (!torch.int, !torch.bool, !torch.optional<tensor>) -> (!torch.optional<tensor>)
return %ret: !torch.optional<tensor>
}
// -----
// CHECK-LABEL: func.func @f
// CHECK: %[[ATEN:.*]] = torch.aten.cos %{{.*}} : !torch.vtensor -> !torch.vtensor<*,f32>
// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[ATEN]] : !torch.vtensor<*,f32> to !torch.vtensor
// CHECK: return %[[CAST]] : !torch.vtensor
func.func @f(%arg0: !torch.vtensor<*,f32>) -> !torch.vtensor {
%cast = torch.tensor_static_info_cast %arg0 : !torch.vtensor<*,f32> to !torch.vtensor
cf.br ^bb1(%cast: !torch.vtensor)
^bb1(%arg1: !torch.vtensor):
%1 = torch.aten.cos %arg1 : !torch.vtensor -> !torch.vtensor
return %1 : !torch.vtensor
}
// -----
// CHECK-LABEL: func.func @f
// CHECK: func.func private @callee
// CHECK-NEXT: torch.aten.cos %{{.*}} : !torch.vtensor -> !torch.vtensor<*,f32>
func.func @f() {
builtin.module {
func.func private @callee(%arg0: !torch.vtensor) {
%1 = torch.aten.cos %arg0 : !torch.vtensor -> !torch.vtensor
return
}
func.func @caller(%arg0: !torch.vtensor<*,f32>) {
%cast = torch.tensor_static_info_cast %arg0 : !torch.vtensor<*,f32> to !torch.vtensor
call @callee(%cast) : (!torch.vtensor) -> ()
return
}
}
return
}

@@ -6,57 +6,6 @@
// Code for testing transfer functions for new ops (which is most changes)
// should go in refine-types-ops.mlir.
// -----
// CHECK-LABEL: func.func @basic(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<*,f32>) -> !torch.vtensor {
// CHECK: %[[COS:.*]] = torch.aten.cos %[[ARG0]] : !torch.vtensor<*,f32> -> !torch.vtensor<*,f32>
// CHECK: %[[RESULT:.*]] = torch.tensor_static_info_cast %[[COS]] : !torch.vtensor<*,f32> to !torch.vtensor
// CHECK: return %[[RESULT]] : !torch.vtensor
func.func @basic(%arg0: !torch.vtensor<*,f32>) -> !torch.vtensor {
%1 = torch.aten.cos %arg0 : !torch.vtensor<*,f32> -> !torch.vtensor
return %1 : !torch.vtensor
}
// -----
// CHECK-LABEL: func.func @keep_existing_shape_information(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<*,f32>) -> !torch.vtensor<[2],f32> {
// CHECK: %[[COS:.*]] = torch.aten.cos %[[ARG0]] : !torch.vtensor<*,f32> -> !torch.vtensor<[2],f32>
// CHECK: return %[[COS]] : !torch.vtensor<[2],f32>
func.func @keep_existing_shape_information(%arg0: !torch.vtensor<*,f32>) -> !torch.vtensor<[2],f32> {
%1 = torch.aten.cos %arg0 : !torch.vtensor<*,f32> -> !torch.vtensor<[2], f32>
return %1 : !torch.vtensor<[2],f32>
}
// -----
// CHECK-LABEL: func.func @propagate_through_multiple_ops(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<*,f32>) -> !torch.vtensor {
// CHECK: %[[COS0:.*]] = torch.aten.cos %[[ARG0]] : !torch.vtensor<*,f32> -> !torch.vtensor<*,f32>
// CHECK: %[[COS1:.*]] = torch.aten.cos %[[COS0]] : !torch.vtensor<*,f32> -> !torch.vtensor<*,f32>
// CHECK: %[[COS2:.*]] = torch.aten.cos %[[COS1]] : !torch.vtensor<*,f32> -> !torch.vtensor<*,f32>
// CHECK: %[[COS3:.*]] = torch.tensor_static_info_cast %[[COS2]] : !torch.vtensor<*,f32> to !torch.vtensor
// CHECK: return %[[COS3]] : !torch.vtensor
func.func @propagate_through_multiple_ops(%arg0: !torch.vtensor<*,f32>) -> !torch.vtensor {
%1 = torch.aten.cos %arg0 : !torch.vtensor<*,f32> -> !torch.vtensor
%2 = torch.aten.cos %1 : !torch.vtensor -> !torch.vtensor
%3 = torch.aten.cos %2 : !torch.vtensor -> !torch.vtensor
return %3 : !torch.vtensor
}
// -----
// Check rewriting logic in case of mixes of users that do/don't allow type
// refinement.
// CHECK-LABEL: func.func @mixed_allowing_not_allowing_type_refinement(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<*,f32>) -> (!torch.vtensor, !torch.vtensor) {
// CHECK: %[[COS0:.*]] = torch.aten.cos %[[ARG0]] : !torch.vtensor<*,f32> -> !torch.vtensor<*,f32>
// CHECK: %[[ERASED:.*]] = torch.tensor_static_info_cast %[[COS0]] : !torch.vtensor<*,f32> to !torch.vtensor
// CHECK: %[[COS1:.*]] = torch.aten.cos %[[COS0]] : !torch.vtensor<*,f32> -> !torch.vtensor<*,f32>
// CHECK: return %[[ERASED]], %[[ERASED]] : !torch.vtensor, !torch.vtensor
func.func @mixed_allowing_not_allowing_type_refinement(%arg0: !torch.vtensor<*,f32>) -> (!torch.vtensor, !torch.vtensor) {
%1 = torch.aten.cos %arg0 : !torch.vtensor<*,f32> -> !torch.vtensor
%3 = torch.aten.cos %1 : !torch.vtensor -> !torch.vtensor
return %1, %1 : !torch.vtensor, !torch.vtensor
}
// -----
// CHECK-LABEL: func.func @torch.overwrite.tensor.contents$dynamic_overwrites_static(