mirror of https://github.com/llvm/torch-mlir
Add dtype functions for floating point ops (#1813)
parent
8cae5ba507
commit
83d4e89d25
|
@ -277,7 +277,7 @@ function test_in_tree() {
|
|||
python -m e2e_testing.main --config=lazy_tensor_core -v
|
||||
|
||||
echo ":::: Run TorchDynamo e2e integration tests"
|
||||
python -m e2e_testing.main --config=torchdynamo -v
|
||||
python -m e2e_testing.main --config=torchdynamo -v --crashing_tests_to_not_attempt_to_run_and_a_bug_is_filed RandnDtypeDeviceModule_basic
|
||||
}
|
||||
|
||||
function setup_venv() {
|
||||
|
|
|
@ -7350,6 +7350,112 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
|
|||
" %4 = torch.prim.ListConstruct %0, %1, %2, %3 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>\n"
|
||||
" return %4 : !torch.list<int>\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_dtype_fn.aten.tanh\"(%arg0: !torch.int, %arg1: !torch.int) -> !torch.int {\n"
|
||||
" %0 = call @__torch__._get_dtype_of_floating_point_op(%arg1) : (!torch.int) -> !torch.int\n"
|
||||
" return %0 : !torch.int\n"
|
||||
" }\n"
|
||||
" func.func @__torch__._get_dtype_of_floating_point_op(%arg0: !torch.int) -> !torch.int {\n"
|
||||
" %int6 = torch.constant.int 6\n"
|
||||
" %int5 = torch.constant.int 5\n"
|
||||
" %int15 = torch.constant.int 15\n"
|
||||
" %int7 = torch.constant.int 7\n"
|
||||
" %0 = torch.prim.ListConstruct %int7, %int15, %int5 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>\n"
|
||||
" %1 = torch.aten.__contains__.int_list %0, %arg0 : !torch.list<int>, !torch.int -> !torch.bool\n"
|
||||
" %2 = torch.prim.If %1 -> (!torch.int) {\n"
|
||||
" torch.prim.If.yield %arg0 : !torch.int\n"
|
||||
" } else {\n"
|
||||
" torch.prim.If.yield %int6 : !torch.int\n"
|
||||
" }\n"
|
||||
" return %2 : !torch.int\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_dtype_fn.aten.exp\"(%arg0: !torch.int, %arg1: !torch.int) -> !torch.int {\n"
|
||||
" %0 = call @__torch__._get_dtype_of_floating_point_op(%arg1) : (!torch.int) -> !torch.int\n"
|
||||
" return %0 : !torch.int\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_dtype_fn.aten.expm1\"(%arg0: !torch.int, %arg1: !torch.int) -> !torch.int {\n"
|
||||
" %0 = call @__torch__._get_dtype_of_floating_point_op(%arg1) : (!torch.int) -> !torch.int\n"
|
||||
" return %0 : !torch.int\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_dtype_fn.aten.sin\"(%arg0: !torch.int, %arg1: !torch.int) -> !torch.int {\n"
|
||||
" %0 = call @__torch__._get_dtype_of_floating_point_op(%arg1) : (!torch.int) -> !torch.int\n"
|
||||
" return %0 : !torch.int\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_dtype_fn.aten.cos\"(%arg0: !torch.int, %arg1: !torch.int) -> !torch.int {\n"
|
||||
" %0 = call @__torch__._get_dtype_of_floating_point_op(%arg1) : (!torch.int) -> !torch.int\n"
|
||||
" return %0 : !torch.int\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_dtype_fn.aten.sigmoid\"(%arg0: !torch.int, %arg1: !torch.int) -> !torch.int {\n"
|
||||
" %0 = call @__torch__._get_dtype_of_floating_point_op(%arg1) : (!torch.int) -> !torch.int\n"
|
||||
" return %0 : !torch.int\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_dtype_fn.aten.reciprocal\"(%arg0: !torch.int, %arg1: !torch.int) -> !torch.int {\n"
|
||||
" %0 = call @__torch__._get_dtype_of_floating_point_op(%arg1) : (!torch.int) -> !torch.int\n"
|
||||
" return %0 : !torch.int\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_dtype_fn.aten.sqrt\"(%arg0: !torch.int, %arg1: !torch.int) -> !torch.int {\n"
|
||||
" %0 = call @__torch__._get_dtype_of_floating_point_op(%arg1) : (!torch.int) -> !torch.int\n"
|
||||
" return %0 : !torch.int\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_dtype_fn.aten.log\"(%arg0: !torch.int, %arg1: !torch.int) -> !torch.int {\n"
|
||||
" %0 = call @__torch__._get_dtype_of_floating_point_op(%arg1) : (!torch.int) -> !torch.int\n"
|
||||
" return %0 : !torch.int\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_dtype_fn.aten.log2\"(%arg0: !torch.int, %arg1: !torch.int) -> !torch.int {\n"
|
||||
" %0 = call @__torch__._get_dtype_of_floating_point_op(%arg1) : (!torch.int) -> !torch.int\n"
|
||||
" return %0 : !torch.int\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_dtype_fn.aten.log1p\"(%arg0: !torch.int, %arg1: !torch.int) -> !torch.int {\n"
|
||||
" %0 = call @__torch__._get_dtype_of_floating_point_op(%arg1) : (!torch.int) -> !torch.int\n"
|
||||
" return %0 : !torch.int\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_dtype_fn.aten.rsqrt\"(%arg0: !torch.int, %arg1: !torch.int) -> !torch.int {\n"
|
||||
" %0 = call @__torch__._get_dtype_of_floating_point_op(%arg1) : (!torch.int) -> !torch.int\n"
|
||||
" return %0 : !torch.int\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_dtype_fn.aten.erf\"(%arg0: !torch.int, %arg1: !torch.int) -> !torch.int {\n"
|
||||
" %0 = call @__torch__._get_dtype_of_floating_point_op(%arg1) : (!torch.int) -> !torch.int\n"
|
||||
" return %0 : !torch.int\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_dtype_fn.aten.softplus\"(%arg0: !torch.int, %arg1: !torch.int, %arg2: !torch.union<float, int>, %arg3: !torch.union<float, int>) -> !torch.int {\n"
|
||||
" %none = torch.constant.none\n"
|
||||
" %str = torch.constant.str \"AssertionError: \"\n"
|
||||
" %int11 = torch.constant.int 11\n"
|
||||
" %int3 = torch.constant.int 3\n"
|
||||
" %int4 = torch.constant.int 4\n"
|
||||
" %0 = torch.prim.ListConstruct %int4, %int3, %int11 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>\n"
|
||||
" %1 = torch.aten.__contains__.int_list %0, %arg1 : !torch.list<int>, !torch.int -> !torch.bool\n"
|
||||
" %2 = torch.aten.__not__ %1 : !torch.bool -> !torch.bool\n"
|
||||
" torch.prim.If %2 -> () {\n"
|
||||
" torch.prim.If.yield\n"
|
||||
" } else {\n"
|
||||
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
|
||||
" torch.prim.If.yield\n"
|
||||
" }\n"
|
||||
" %3 = call @__torch__._get_dtype_of_floating_point_op(%arg1) : (!torch.int) -> !torch.int\n"
|
||||
" return %3 : !torch.int\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_dtype_fn.aten.frobenius_norm.dim\"(%arg0: !torch.int, %arg1: !torch.int, %arg2: !torch.list<int>, %arg3: !torch.bool) -> !torch.int {\n"
|
||||
" %none = torch.constant.none\n"
|
||||
" %str = torch.constant.str \"AssertionError: \"\n"
|
||||
" %int11 = torch.constant.int 11\n"
|
||||
" %int3 = torch.constant.int 3\n"
|
||||
" %int4 = torch.constant.int 4\n"
|
||||
" %0 = torch.prim.ListConstruct %int4, %int3, %int11 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>\n"
|
||||
" %1 = torch.aten.__contains__.int_list %0, %arg1 : !torch.list<int>, !torch.int -> !torch.bool\n"
|
||||
" %2 = torch.aten.__not__ %1 : !torch.bool -> !torch.bool\n"
|
||||
" torch.prim.If %2 -> () {\n"
|
||||
" torch.prim.If.yield\n"
|
||||
" } else {\n"
|
||||
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
|
||||
" torch.prim.If.yield\n"
|
||||
" }\n"
|
||||
" %3 = call @__torch__._get_dtype_of_floating_point_op(%arg1) : (!torch.int) -> !torch.int\n"
|
||||
" return %3 : !torch.int\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_dtype_fn.prims.sqrt\"(%arg0: !torch.int, %arg1: !torch.int) -> !torch.int {\n"
|
||||
" %0 = call @__torch__._get_dtype_of_floating_point_op(%arg1) : (!torch.int) -> !torch.int\n"
|
||||
" return %0 : !torch.int\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_dtype_fn.aten.all\"(%arg0: !torch.int, %arg1: !torch.int) -> !torch.int {\n"
|
||||
" %int11 = torch.constant.int 11\n"
|
||||
" return %int11 : !torch.int\n"
|
||||
|
@ -7520,32 +7626,6 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
|
|||
" %3 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%0, %2) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n"
|
||||
" return %3 : !torch.int\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_dtype_fn.aten.expm1\"(%arg0: !torch.int, %arg1: !torch.int) -> !torch.int {\n"
|
||||
" %int6 = torch.constant.int 6\n"
|
||||
" %int5 = torch.constant.int 5\n"
|
||||
" %int15 = torch.constant.int 15\n"
|
||||
" %true = torch.constant.bool true\n"
|
||||
" %int7 = torch.constant.int 7\n"
|
||||
" %0 = torch.aten.eq.int %arg1, %int7 : !torch.int, !torch.int -> !torch.bool\n"
|
||||
" %1 = torch.prim.If %0 -> (!torch.bool) {\n"
|
||||
" torch.prim.If.yield %true : !torch.bool\n"
|
||||
" } else {\n"
|
||||
" %4 = torch.aten.eq.int %arg1, %int15 : !torch.int, !torch.int -> !torch.bool\n"
|
||||
" torch.prim.If.yield %4 : !torch.bool\n"
|
||||
" }\n"
|
||||
" %2 = torch.prim.If %1 -> (!torch.bool) {\n"
|
||||
" torch.prim.If.yield %true : !torch.bool\n"
|
||||
" } else {\n"
|
||||
" %4 = torch.aten.eq.int %arg1, %int5 : !torch.int, !torch.int -> !torch.bool\n"
|
||||
" torch.prim.If.yield %4 : !torch.bool\n"
|
||||
" }\n"
|
||||
" %3 = torch.prim.If %2 -> (!torch.int) {\n"
|
||||
" torch.prim.If.yield %arg1 : !torch.int\n"
|
||||
" } else {\n"
|
||||
" torch.prim.If.yield %int6 : !torch.int\n"
|
||||
" }\n"
|
||||
" return %3 : !torch.int\n"
|
||||
" }\n"
|
||||
"}\n"
|
||||
"";
|
||||
// clang-format on
|
||||
|
|
|
@ -672,23 +672,6 @@ void TypeAnalysis::visitOperation(Operation *op,
|
|||
return incorporateKnowledge(op->getResult(0), operands[0]->getValue());
|
||||
}
|
||||
|
||||
// Dtype is always float32, except for bfloat16, float16, float64 and nullptr.
|
||||
if (isa<AtenTanhOp, AtenExpOp, AtenSinOp, AtenCosOp, AtenSigmoidOp,
|
||||
AtenReciprocalOp, AtenLogOp, AtenSqrtOp, AtenLog2Op, AtenLog1pOp,
|
||||
AtenRsqrtOp, AtenErfOp, AtenSoftplusOp, AtenFrobeniusNormDimOp,
|
||||
PrimsSqrtOp>(op)) {
|
||||
ValueKnowledge knowledge =
|
||||
ValueKnowledge::getTensorPessimisticValueState(op->getContext());
|
||||
Type dtype = operands[0]->getValue().dtype;
|
||||
if (dtype) {
|
||||
knowledge.dtype = Float32Type::get(op->getContext());
|
||||
if (dtype.isa<BFloat16Type, Float16Type, Float64Type>())
|
||||
knowledge.dtype = dtype;
|
||||
}
|
||||
incorporateKnowledge(op->getResult(0), knowledge);
|
||||
return;
|
||||
}
|
||||
|
||||
// Take dtype from second operand.
|
||||
if (isa<AtenNllLossBackwardOp, AtenMaxPool2dWithIndicesBackwardOp>(op)) {
|
||||
auto self = operands[1]->getValue();
|
||||
|
|
|
@ -1026,19 +1026,107 @@ def _get_invocations_for_op_with_tensor_arg_followed_by(*args):
|
|||
dtype function instead of using this helper function.
|
||||
"""
|
||||
return [
|
||||
Invocation(NonZeroDTensorWithDtype(torch.float32), *args),
|
||||
Invocation(NonZeroDTensorWithDtype(torch.float64), *args),
|
||||
Invocation(NonZeroDTensorWithDtype(torch.bfloat16), *args),
|
||||
Invocation(NonZeroDTensorWithDtype(torch.int64), *args),
|
||||
Invocation(NonZeroDTensorWithDtype(torch.int32), *args),
|
||||
Invocation(NonZeroDTensorWithDtype(torch.bool), *args),
|
||||
Invocation(ZeroDTensorWithDtype(torch.float32), *args),
|
||||
Invocation(ZeroDTensorWithDtype(torch.float64), *args),
|
||||
Invocation(ZeroDTensorWithDtype(torch.bfloat16), *args),
|
||||
Invocation(ZeroDTensorWithDtype(torch.int64), *args),
|
||||
Invocation(ZeroDTensorWithDtype(torch.int32), *args),
|
||||
Invocation(ZeroDTensorWithDtype(torch.bool), *args),
|
||||
]
|
||||
Invocation(NonZeroDTensorWithDtype(torch.float32), *args),
|
||||
Invocation(NonZeroDTensorWithDtype(torch.float64), *args),
|
||||
Invocation(NonZeroDTensorWithDtype(torch.bfloat16), *args),
|
||||
Invocation(NonZeroDTensorWithDtype(torch.int64), *args),
|
||||
Invocation(NonZeroDTensorWithDtype(torch.int32), *args),
|
||||
Invocation(NonZeroDTensorWithDtype(torch.bool), *args),
|
||||
Invocation(ZeroDTensorWithDtype(torch.float32), *args),
|
||||
Invocation(ZeroDTensorWithDtype(torch.float64), *args),
|
||||
Invocation(ZeroDTensorWithDtype(torch.bfloat16), *args),
|
||||
Invocation(ZeroDTensorWithDtype(torch.int64), *args),
|
||||
Invocation(ZeroDTensorWithDtype(torch.int32), *args),
|
||||
Invocation(ZeroDTensorWithDtype(torch.bool), *args),
|
||||
]
|
||||
|
||||
def _get_invocations_for_fp_only_op_with_tensor_arg_followed_by(*args):
|
||||
"""Generate invocations for floating point only op."""
|
||||
return [
|
||||
Invocation(NonZeroDTensorWithDtype(torch.float32), *args),
|
||||
Invocation(NonZeroDTensorWithDtype(torch.float64), *args),
|
||||
Invocation(NonZeroDTensorWithDtype(torch.bfloat16), *args),
|
||||
ErrorInvocation(NonZeroDTensorWithDtype(torch.int64), *args),
|
||||
ErrorInvocation(NonZeroDTensorWithDtype(torch.int32), *args),
|
||||
ErrorInvocation(NonZeroDTensorWithDtype(torch.bool), *args),
|
||||
Invocation(ZeroDTensorWithDtype(torch.float32), *args),
|
||||
Invocation(ZeroDTensorWithDtype(torch.float64), *args),
|
||||
Invocation(ZeroDTensorWithDtype(torch.bfloat16), *args),
|
||||
ErrorInvocation(ZeroDTensorWithDtype(torch.int64), *args),
|
||||
ErrorInvocation(ZeroDTensorWithDtype(torch.int32), *args),
|
||||
ErrorInvocation(ZeroDTensorWithDtype(torch.bool), *args),
|
||||
]
|
||||
|
||||
def _get_dtype_of_floating_point_op(input_dtype: int) -> int:
|
||||
if input_dtype in (torch.float64, torch.bfloat16, torch.float16):
|
||||
return input_dtype
|
||||
return torch.float32
|
||||
|
||||
@check_dtype_function(_get_invocations_for_op_with_tensor_arg_followed_by())
|
||||
def aten〇tanh〡dtype(self_rank: int, self_dtype: int) -> int:
|
||||
return _get_dtype_of_floating_point_op(self_dtype)
|
||||
|
||||
@check_dtype_function(_get_invocations_for_op_with_tensor_arg_followed_by())
|
||||
def aten〇exp〡dtype(self_rank: int, self_dtype: int) -> int:
|
||||
return _get_dtype_of_floating_point_op(self_dtype)
|
||||
|
||||
@check_dtype_function(_get_invocations_for_op_with_tensor_arg_followed_by())
|
||||
def aten〇expm1〡dtype(self_rank: int, self_dtype: int) -> int:
|
||||
return _get_dtype_of_floating_point_op(self_dtype)
|
||||
|
||||
@check_dtype_function(_get_invocations_for_op_with_tensor_arg_followed_by())
|
||||
def aten〇sin〡dtype(self_rank: int, self_dtype: int) -> int:
|
||||
return _get_dtype_of_floating_point_op(self_dtype)
|
||||
|
||||
@check_dtype_function(_get_invocations_for_op_with_tensor_arg_followed_by())
|
||||
def aten〇cos〡dtype(self_rank: int, self_dtype: int) -> int:
|
||||
return _get_dtype_of_floating_point_op(self_dtype)
|
||||
|
||||
@check_dtype_function(_get_invocations_for_op_with_tensor_arg_followed_by())
|
||||
def aten〇sigmoid〡dtype(self_rank: int, self_dtype: int) -> int:
|
||||
return _get_dtype_of_floating_point_op(self_dtype)
|
||||
|
||||
@check_dtype_function(_get_invocations_for_op_with_tensor_arg_followed_by())
|
||||
def aten〇reciprocal〡dtype(self_rank: int, self_dtype: int) -> int:
|
||||
return _get_dtype_of_floating_point_op(self_dtype)
|
||||
|
||||
@check_dtype_function(_get_invocations_for_op_with_tensor_arg_followed_by())
|
||||
def aten〇sqrt〡dtype(self_rank: int, self_dtype: int) -> int:
|
||||
return _get_dtype_of_floating_point_op(self_dtype)
|
||||
|
||||
@check_dtype_function(_get_invocations_for_op_with_tensor_arg_followed_by())
|
||||
def aten〇log〡dtype(self_rank: int, self_dtype: int) -> int:
|
||||
return _get_dtype_of_floating_point_op(self_dtype)
|
||||
|
||||
@check_dtype_function(_get_invocations_for_op_with_tensor_arg_followed_by())
|
||||
def aten〇log2〡dtype(self_rank: int, self_dtype: int) -> int:
|
||||
return _get_dtype_of_floating_point_op(self_dtype)
|
||||
|
||||
@check_dtype_function(_get_invocations_for_op_with_tensor_arg_followed_by())
|
||||
def aten〇log1p〡dtype(self_rank: int, self_dtype: int) -> int:
|
||||
return _get_dtype_of_floating_point_op(self_dtype)
|
||||
|
||||
@check_dtype_function(_get_invocations_for_op_with_tensor_arg_followed_by())
|
||||
def aten〇rsqrt〡dtype(self_rank: int, self_dtype: int) -> int:
|
||||
return _get_dtype_of_floating_point_op(self_dtype)
|
||||
|
||||
@check_dtype_function(_get_invocations_for_op_with_tensor_arg_followed_by())
|
||||
def aten〇erf〡dtype(self_rank: int, self_dtype: int) -> int:
|
||||
return _get_dtype_of_floating_point_op(self_dtype)
|
||||
|
||||
@check_dtype_function(_get_invocations_for_fp_only_op_with_tensor_arg_followed_by())
|
||||
def aten〇softplus〡dtype(self_rank: int, self_dtype: int, beta: Union[int, float] = 1, threshold: Union[int, float] = 20) -> int:
|
||||
assert self_dtype not in (torch.int64, torch.int32, torch.bool)
|
||||
return _get_dtype_of_floating_point_op(self_dtype)
|
||||
|
||||
@check_dtype_function(_get_invocations_for_fp_only_op_with_tensor_arg_followed_by([0]))
|
||||
def aten〇frobenius_norm〇dim〡dtype(self_rank: int, self_dtype: int, dim: List[int], keepdim: bool = False) -> int:
|
||||
assert self_dtype not in (torch.int64, torch.int32, torch.bool)
|
||||
return _get_dtype_of_floating_point_op(self_dtype)
|
||||
|
||||
@check_dtype_function(_get_invocations_for_op_with_tensor_arg_followed_by())
|
||||
def prims〇sqrt〡dtype(self_rank: int, self_dtype: int) -> int:
|
||||
return _get_dtype_of_floating_point_op(self_dtype)
|
||||
|
||||
@check_dtype_function(_get_invocations_for_op_with_tensor_arg_followed_by())
|
||||
def aten〇all〡dtype(self_rank: int, self_dtype: int) -> int:
|
||||
|
@ -1167,13 +1255,6 @@ def aten〇floor_divide〡dtype(self_rank: int, self_dtype: int, other_rank: int
|
|||
def aten〇rsub〇Scalar〡dtype(self_rank: int, self_dtype: int, other: Union[int, float], alpha: Union[int, float] = 1) -> int:
|
||||
return promote_dtypes([self_rank, None], [self_dtype, get_dtype_of_scalar(other)])
|
||||
|
||||
@check_dtype_function(_get_invocations_for_op_with_tensor_arg_followed_by())
|
||||
def aten〇expm1〡dtype(self_rank: int, self_dtype: int) -> int:
|
||||
if self_dtype == torch.float64 or self_dtype == torch.bfloat16 or self_dtype == torch.float16:
|
||||
return self_dtype
|
||||
else:
|
||||
return torch.float32
|
||||
|
||||
# ==============================================================================
|
||||
# Main
|
||||
# ==============================================================================
|
||||
|
|
|
@ -117,37 +117,3 @@ func.func @prim.loop$region_arg_to_internal(%none: !torch.none) -> !torch.option
|
|||
} : (!torch.int, !torch.bool, !torch.optional<tensor>) -> (!torch.optional<tensor>)
|
||||
return %ret: !torch.optional<tensor>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @f
|
||||
// CHECK: %[[ATEN:.*]] = torch.aten.cos %{{.*}} : !torch.vtensor -> !torch.vtensor<*,f32>
|
||||
// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[ATEN]] : !torch.vtensor<*,f32> to !torch.vtensor
|
||||
// CHECK: return %[[CAST]] : !torch.vtensor
|
||||
func.func @f(%arg0: !torch.vtensor<*,f32>) -> !torch.vtensor {
|
||||
%cast = torch.tensor_static_info_cast %arg0 : !torch.vtensor<*,f32> to !torch.vtensor
|
||||
cf.br ^bb1(%cast: !torch.vtensor)
|
||||
^bb1(%arg1: !torch.vtensor):
|
||||
%1 = torch.aten.cos %arg1 : !torch.vtensor -> !torch.vtensor
|
||||
return %1 : !torch.vtensor
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @f
|
||||
// CHECK: func.func private @callee
|
||||
// CHECK-NEXT: torch.aten.cos %{{.*}} : !torch.vtensor -> !torch.vtensor<*,f32>
|
||||
func.func @f() {
|
||||
builtin.module {
|
||||
func.func private @callee(%arg0: !torch.vtensor) {
|
||||
%1 = torch.aten.cos %arg0 : !torch.vtensor -> !torch.vtensor
|
||||
return
|
||||
}
|
||||
func.func @caller(%arg0: !torch.vtensor<*,f32>) {
|
||||
%cast = torch.tensor_static_info_cast %arg0 : !torch.vtensor<*,f32> to !torch.vtensor
|
||||
call @callee(%cast) : (!torch.vtensor) -> ()
|
||||
return
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
|
|
@ -6,57 +6,6 @@
|
|||
// Code for testing transfer functions for new ops (which is most changes)
|
||||
// should go in refine-types-ops.mlir.
|
||||
|
||||
// -----
|
||||
// CHECK-LABEL: func.func @basic(
|
||||
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<*,f32>) -> !torch.vtensor {
|
||||
// CHECK: %[[COS:.*]] = torch.aten.cos %[[ARG0]] : !torch.vtensor<*,f32> -> !torch.vtensor<*,f32>
|
||||
// CHECK: %[[RESULT:.*]] = torch.tensor_static_info_cast %[[COS]] : !torch.vtensor<*,f32> to !torch.vtensor
|
||||
// CHECK: return %[[RESULT]] : !torch.vtensor
|
||||
func.func @basic(%arg0: !torch.vtensor<*,f32>) -> !torch.vtensor {
|
||||
%1 = torch.aten.cos %arg0 : !torch.vtensor<*,f32> -> !torch.vtensor
|
||||
return %1 : !torch.vtensor
|
||||
}
|
||||
|
||||
// -----
|
||||
// CHECK-LABEL: func.func @keep_existing_shape_information(
|
||||
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<*,f32>) -> !torch.vtensor<[2],f32> {
|
||||
// CHECK: %[[COS:.*]] = torch.aten.cos %[[ARG0]] : !torch.vtensor<*,f32> -> !torch.vtensor<[2],f32>
|
||||
// CHECK: return %[[COS]] : !torch.vtensor<[2],f32>
|
||||
func.func @keep_existing_shape_information(%arg0: !torch.vtensor<*,f32>) -> !torch.vtensor<[2],f32> {
|
||||
%1 = torch.aten.cos %arg0 : !torch.vtensor<*,f32> -> !torch.vtensor<[2], f32>
|
||||
return %1 : !torch.vtensor<[2],f32>
|
||||
}
|
||||
|
||||
// -----
|
||||
// CHECK-LABEL: func.func @propagate_through_multiple_ops(
|
||||
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<*,f32>) -> !torch.vtensor {
|
||||
// CHECK: %[[COS0:.*]] = torch.aten.cos %[[ARG0]] : !torch.vtensor<*,f32> -> !torch.vtensor<*,f32>
|
||||
// CHECK: %[[COS1:.*]] = torch.aten.cos %[[COS0]] : !torch.vtensor<*,f32> -> !torch.vtensor<*,f32>
|
||||
// CHECK: %[[COS2:.*]] = torch.aten.cos %[[COS1]] : !torch.vtensor<*,f32> -> !torch.vtensor<*,f32>
|
||||
// CHECK: %[[COS3:.*]] = torch.tensor_static_info_cast %[[COS2]] : !torch.vtensor<*,f32> to !torch.vtensor
|
||||
// CHECK: return %[[COS3]] : !torch.vtensor
|
||||
func.func @propagate_through_multiple_ops(%arg0: !torch.vtensor<*,f32>) -> !torch.vtensor {
|
||||
%1 = torch.aten.cos %arg0 : !torch.vtensor<*,f32> -> !torch.vtensor
|
||||
%2 = torch.aten.cos %1 : !torch.vtensor -> !torch.vtensor
|
||||
%3 = torch.aten.cos %2 : !torch.vtensor -> !torch.vtensor
|
||||
return %3 : !torch.vtensor
|
||||
}
|
||||
|
||||
// -----
|
||||
// Check rewriting logic in case of mixes of users that do/don't allow type
|
||||
// refinement.
|
||||
// CHECK-LABEL: func.func @mixed_allowing_not_allowing_type_refinement(
|
||||
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<*,f32>) -> (!torch.vtensor, !torch.vtensor) {
|
||||
// CHECK: %[[COS0:.*]] = torch.aten.cos %[[ARG0]] : !torch.vtensor<*,f32> -> !torch.vtensor<*,f32>
|
||||
// CHECK: %[[ERASED:.*]] = torch.tensor_static_info_cast %[[COS0]] : !torch.vtensor<*,f32> to !torch.vtensor
|
||||
// CHECK: %[[COS1:.*]] = torch.aten.cos %[[COS0]] : !torch.vtensor<*,f32> -> !torch.vtensor<*,f32>
|
||||
// CHECK: return %[[ERASED]], %[[ERASED]] : !torch.vtensor, !torch.vtensor
|
||||
func.func @mixed_allowing_not_allowing_type_refinement(%arg0: !torch.vtensor<*,f32>) -> (!torch.vtensor, !torch.vtensor) {
|
||||
%1 = torch.aten.cos %arg0 : !torch.vtensor<*,f32> -> !torch.vtensor
|
||||
%3 = torch.aten.cos %1 : !torch.vtensor -> !torch.vtensor
|
||||
return %1, %1 : !torch.vtensor, !torch.vtensor
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @torch.overwrite.tensor.contents$dynamic_overwrites_static(
|
||||
|
|
Loading…
Reference in New Issue