mirror of https://github.com/llvm/torch-mlir
Add dtype functions for floating point ops (#1813)
parent
8cae5ba507
commit
83d4e89d25
|
@ -277,7 +277,7 @@ function test_in_tree() {
|
||||||
python -m e2e_testing.main --config=lazy_tensor_core -v
|
python -m e2e_testing.main --config=lazy_tensor_core -v
|
||||||
|
|
||||||
echo ":::: Run TorchDynamo e2e integration tests"
|
echo ":::: Run TorchDynamo e2e integration tests"
|
||||||
python -m e2e_testing.main --config=torchdynamo -v
|
python -m e2e_testing.main --config=torchdynamo -v --crashing_tests_to_not_attempt_to_run_and_a_bug_is_filed RandnDtypeDeviceModule_basic
|
||||||
}
|
}
|
||||||
|
|
||||||
function setup_venv() {
|
function setup_venv() {
|
||||||
|
|
|
@ -7350,6 +7350,112 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
|
||||||
" %4 = torch.prim.ListConstruct %0, %1, %2, %3 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>\n"
|
" %4 = torch.prim.ListConstruct %0, %1, %2, %3 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>\n"
|
||||||
" return %4 : !torch.list<int>\n"
|
" return %4 : !torch.list<int>\n"
|
||||||
" }\n"
|
" }\n"
|
||||||
|
" func.func @\"__torch_mlir_dtype_fn.aten.tanh\"(%arg0: !torch.int, %arg1: !torch.int) -> !torch.int {\n"
|
||||||
|
" %0 = call @__torch__._get_dtype_of_floating_point_op(%arg1) : (!torch.int) -> !torch.int\n"
|
||||||
|
" return %0 : !torch.int\n"
|
||||||
|
" }\n"
|
||||||
|
" func.func @__torch__._get_dtype_of_floating_point_op(%arg0: !torch.int) -> !torch.int {\n"
|
||||||
|
" %int6 = torch.constant.int 6\n"
|
||||||
|
" %int5 = torch.constant.int 5\n"
|
||||||
|
" %int15 = torch.constant.int 15\n"
|
||||||
|
" %int7 = torch.constant.int 7\n"
|
||||||
|
" %0 = torch.prim.ListConstruct %int7, %int15, %int5 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>\n"
|
||||||
|
" %1 = torch.aten.__contains__.int_list %0, %arg0 : !torch.list<int>, !torch.int -> !torch.bool\n"
|
||||||
|
" %2 = torch.prim.If %1 -> (!torch.int) {\n"
|
||||||
|
" torch.prim.If.yield %arg0 : !torch.int\n"
|
||||||
|
" } else {\n"
|
||||||
|
" torch.prim.If.yield %int6 : !torch.int\n"
|
||||||
|
" }\n"
|
||||||
|
" return %2 : !torch.int\n"
|
||||||
|
" }\n"
|
||||||
|
" func.func @\"__torch_mlir_dtype_fn.aten.exp\"(%arg0: !torch.int, %arg1: !torch.int) -> !torch.int {\n"
|
||||||
|
" %0 = call @__torch__._get_dtype_of_floating_point_op(%arg1) : (!torch.int) -> !torch.int\n"
|
||||||
|
" return %0 : !torch.int\n"
|
||||||
|
" }\n"
|
||||||
|
" func.func @\"__torch_mlir_dtype_fn.aten.expm1\"(%arg0: !torch.int, %arg1: !torch.int) -> !torch.int {\n"
|
||||||
|
" %0 = call @__torch__._get_dtype_of_floating_point_op(%arg1) : (!torch.int) -> !torch.int\n"
|
||||||
|
" return %0 : !torch.int\n"
|
||||||
|
" }\n"
|
||||||
|
" func.func @\"__torch_mlir_dtype_fn.aten.sin\"(%arg0: !torch.int, %arg1: !torch.int) -> !torch.int {\n"
|
||||||
|
" %0 = call @__torch__._get_dtype_of_floating_point_op(%arg1) : (!torch.int) -> !torch.int\n"
|
||||||
|
" return %0 : !torch.int\n"
|
||||||
|
" }\n"
|
||||||
|
" func.func @\"__torch_mlir_dtype_fn.aten.cos\"(%arg0: !torch.int, %arg1: !torch.int) -> !torch.int {\n"
|
||||||
|
" %0 = call @__torch__._get_dtype_of_floating_point_op(%arg1) : (!torch.int) -> !torch.int\n"
|
||||||
|
" return %0 : !torch.int\n"
|
||||||
|
" }\n"
|
||||||
|
" func.func @\"__torch_mlir_dtype_fn.aten.sigmoid\"(%arg0: !torch.int, %arg1: !torch.int) -> !torch.int {\n"
|
||||||
|
" %0 = call @__torch__._get_dtype_of_floating_point_op(%arg1) : (!torch.int) -> !torch.int\n"
|
||||||
|
" return %0 : !torch.int\n"
|
||||||
|
" }\n"
|
||||||
|
" func.func @\"__torch_mlir_dtype_fn.aten.reciprocal\"(%arg0: !torch.int, %arg1: !torch.int) -> !torch.int {\n"
|
||||||
|
" %0 = call @__torch__._get_dtype_of_floating_point_op(%arg1) : (!torch.int) -> !torch.int\n"
|
||||||
|
" return %0 : !torch.int\n"
|
||||||
|
" }\n"
|
||||||
|
" func.func @\"__torch_mlir_dtype_fn.aten.sqrt\"(%arg0: !torch.int, %arg1: !torch.int) -> !torch.int {\n"
|
||||||
|
" %0 = call @__torch__._get_dtype_of_floating_point_op(%arg1) : (!torch.int) -> !torch.int\n"
|
||||||
|
" return %0 : !torch.int\n"
|
||||||
|
" }\n"
|
||||||
|
" func.func @\"__torch_mlir_dtype_fn.aten.log\"(%arg0: !torch.int, %arg1: !torch.int) -> !torch.int {\n"
|
||||||
|
" %0 = call @__torch__._get_dtype_of_floating_point_op(%arg1) : (!torch.int) -> !torch.int\n"
|
||||||
|
" return %0 : !torch.int\n"
|
||||||
|
" }\n"
|
||||||
|
" func.func @\"__torch_mlir_dtype_fn.aten.log2\"(%arg0: !torch.int, %arg1: !torch.int) -> !torch.int {\n"
|
||||||
|
" %0 = call @__torch__._get_dtype_of_floating_point_op(%arg1) : (!torch.int) -> !torch.int\n"
|
||||||
|
" return %0 : !torch.int\n"
|
||||||
|
" }\n"
|
||||||
|
" func.func @\"__torch_mlir_dtype_fn.aten.log1p\"(%arg0: !torch.int, %arg1: !torch.int) -> !torch.int {\n"
|
||||||
|
" %0 = call @__torch__._get_dtype_of_floating_point_op(%arg1) : (!torch.int) -> !torch.int\n"
|
||||||
|
" return %0 : !torch.int\n"
|
||||||
|
" }\n"
|
||||||
|
" func.func @\"__torch_mlir_dtype_fn.aten.rsqrt\"(%arg0: !torch.int, %arg1: !torch.int) -> !torch.int {\n"
|
||||||
|
" %0 = call @__torch__._get_dtype_of_floating_point_op(%arg1) : (!torch.int) -> !torch.int\n"
|
||||||
|
" return %0 : !torch.int\n"
|
||||||
|
" }\n"
|
||||||
|
" func.func @\"__torch_mlir_dtype_fn.aten.erf\"(%arg0: !torch.int, %arg1: !torch.int) -> !torch.int {\n"
|
||||||
|
" %0 = call @__torch__._get_dtype_of_floating_point_op(%arg1) : (!torch.int) -> !torch.int\n"
|
||||||
|
" return %0 : !torch.int\n"
|
||||||
|
" }\n"
|
||||||
|
" func.func @\"__torch_mlir_dtype_fn.aten.softplus\"(%arg0: !torch.int, %arg1: !torch.int, %arg2: !torch.union<float, int>, %arg3: !torch.union<float, int>) -> !torch.int {\n"
|
||||||
|
" %none = torch.constant.none\n"
|
||||||
|
" %str = torch.constant.str \"AssertionError: \"\n"
|
||||||
|
" %int11 = torch.constant.int 11\n"
|
||||||
|
" %int3 = torch.constant.int 3\n"
|
||||||
|
" %int4 = torch.constant.int 4\n"
|
||||||
|
" %0 = torch.prim.ListConstruct %int4, %int3, %int11 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>\n"
|
||||||
|
" %1 = torch.aten.__contains__.int_list %0, %arg1 : !torch.list<int>, !torch.int -> !torch.bool\n"
|
||||||
|
" %2 = torch.aten.__not__ %1 : !torch.bool -> !torch.bool\n"
|
||||||
|
" torch.prim.If %2 -> () {\n"
|
||||||
|
" torch.prim.If.yield\n"
|
||||||
|
" } else {\n"
|
||||||
|
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
|
||||||
|
" torch.prim.If.yield\n"
|
||||||
|
" }\n"
|
||||||
|
" %3 = call @__torch__._get_dtype_of_floating_point_op(%arg1) : (!torch.int) -> !torch.int\n"
|
||||||
|
" return %3 : !torch.int\n"
|
||||||
|
" }\n"
|
||||||
|
" func.func @\"__torch_mlir_dtype_fn.aten.frobenius_norm.dim\"(%arg0: !torch.int, %arg1: !torch.int, %arg2: !torch.list<int>, %arg3: !torch.bool) -> !torch.int {\n"
|
||||||
|
" %none = torch.constant.none\n"
|
||||||
|
" %str = torch.constant.str \"AssertionError: \"\n"
|
||||||
|
" %int11 = torch.constant.int 11\n"
|
||||||
|
" %int3 = torch.constant.int 3\n"
|
||||||
|
" %int4 = torch.constant.int 4\n"
|
||||||
|
" %0 = torch.prim.ListConstruct %int4, %int3, %int11 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>\n"
|
||||||
|
" %1 = torch.aten.__contains__.int_list %0, %arg1 : !torch.list<int>, !torch.int -> !torch.bool\n"
|
||||||
|
" %2 = torch.aten.__not__ %1 : !torch.bool -> !torch.bool\n"
|
||||||
|
" torch.prim.If %2 -> () {\n"
|
||||||
|
" torch.prim.If.yield\n"
|
||||||
|
" } else {\n"
|
||||||
|
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
|
||||||
|
" torch.prim.If.yield\n"
|
||||||
|
" }\n"
|
||||||
|
" %3 = call @__torch__._get_dtype_of_floating_point_op(%arg1) : (!torch.int) -> !torch.int\n"
|
||||||
|
" return %3 : !torch.int\n"
|
||||||
|
" }\n"
|
||||||
|
" func.func @\"__torch_mlir_dtype_fn.prims.sqrt\"(%arg0: !torch.int, %arg1: !torch.int) -> !torch.int {\n"
|
||||||
|
" %0 = call @__torch__._get_dtype_of_floating_point_op(%arg1) : (!torch.int) -> !torch.int\n"
|
||||||
|
" return %0 : !torch.int\n"
|
||||||
|
" }\n"
|
||||||
" func.func @\"__torch_mlir_dtype_fn.aten.all\"(%arg0: !torch.int, %arg1: !torch.int) -> !torch.int {\n"
|
" func.func @\"__torch_mlir_dtype_fn.aten.all\"(%arg0: !torch.int, %arg1: !torch.int) -> !torch.int {\n"
|
||||||
" %int11 = torch.constant.int 11\n"
|
" %int11 = torch.constant.int 11\n"
|
||||||
" return %int11 : !torch.int\n"
|
" return %int11 : !torch.int\n"
|
||||||
|
@ -7520,32 +7626,6 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
|
||||||
" %3 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%0, %2) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n"
|
" %3 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%0, %2) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n"
|
||||||
" return %3 : !torch.int\n"
|
" return %3 : !torch.int\n"
|
||||||
" }\n"
|
" }\n"
|
||||||
" func.func @\"__torch_mlir_dtype_fn.aten.expm1\"(%arg0: !torch.int, %arg1: !torch.int) -> !torch.int {\n"
|
|
||||||
" %int6 = torch.constant.int 6\n"
|
|
||||||
" %int5 = torch.constant.int 5\n"
|
|
||||||
" %int15 = torch.constant.int 15\n"
|
|
||||||
" %true = torch.constant.bool true\n"
|
|
||||||
" %int7 = torch.constant.int 7\n"
|
|
||||||
" %0 = torch.aten.eq.int %arg1, %int7 : !torch.int, !torch.int -> !torch.bool\n"
|
|
||||||
" %1 = torch.prim.If %0 -> (!torch.bool) {\n"
|
|
||||||
" torch.prim.If.yield %true : !torch.bool\n"
|
|
||||||
" } else {\n"
|
|
||||||
" %4 = torch.aten.eq.int %arg1, %int15 : !torch.int, !torch.int -> !torch.bool\n"
|
|
||||||
" torch.prim.If.yield %4 : !torch.bool\n"
|
|
||||||
" }\n"
|
|
||||||
" %2 = torch.prim.If %1 -> (!torch.bool) {\n"
|
|
||||||
" torch.prim.If.yield %true : !torch.bool\n"
|
|
||||||
" } else {\n"
|
|
||||||
" %4 = torch.aten.eq.int %arg1, %int5 : !torch.int, !torch.int -> !torch.bool\n"
|
|
||||||
" torch.prim.If.yield %4 : !torch.bool\n"
|
|
||||||
" }\n"
|
|
||||||
" %3 = torch.prim.If %2 -> (!torch.int) {\n"
|
|
||||||
" torch.prim.If.yield %arg1 : !torch.int\n"
|
|
||||||
" } else {\n"
|
|
||||||
" torch.prim.If.yield %int6 : !torch.int\n"
|
|
||||||
" }\n"
|
|
||||||
" return %3 : !torch.int\n"
|
|
||||||
" }\n"
|
|
||||||
"}\n"
|
"}\n"
|
||||||
"";
|
"";
|
||||||
// clang-format on
|
// clang-format on
|
||||||
|
|
|
@ -672,23 +672,6 @@ void TypeAnalysis::visitOperation(Operation *op,
|
||||||
return incorporateKnowledge(op->getResult(0), operands[0]->getValue());
|
return incorporateKnowledge(op->getResult(0), operands[0]->getValue());
|
||||||
}
|
}
|
||||||
|
|
||||||
// Dtype is always float32, except for bfloat16, float16, float64 and nullptr.
|
|
||||||
if (isa<AtenTanhOp, AtenExpOp, AtenSinOp, AtenCosOp, AtenSigmoidOp,
|
|
||||||
AtenReciprocalOp, AtenLogOp, AtenSqrtOp, AtenLog2Op, AtenLog1pOp,
|
|
||||||
AtenRsqrtOp, AtenErfOp, AtenSoftplusOp, AtenFrobeniusNormDimOp,
|
|
||||||
PrimsSqrtOp>(op)) {
|
|
||||||
ValueKnowledge knowledge =
|
|
||||||
ValueKnowledge::getTensorPessimisticValueState(op->getContext());
|
|
||||||
Type dtype = operands[0]->getValue().dtype;
|
|
||||||
if (dtype) {
|
|
||||||
knowledge.dtype = Float32Type::get(op->getContext());
|
|
||||||
if (dtype.isa<BFloat16Type, Float16Type, Float64Type>())
|
|
||||||
knowledge.dtype = dtype;
|
|
||||||
}
|
|
||||||
incorporateKnowledge(op->getResult(0), knowledge);
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Take dtype from second operand.
|
// Take dtype from second operand.
|
||||||
if (isa<AtenNllLossBackwardOp, AtenMaxPool2dWithIndicesBackwardOp>(op)) {
|
if (isa<AtenNllLossBackwardOp, AtenMaxPool2dWithIndicesBackwardOp>(op)) {
|
||||||
auto self = operands[1]->getValue();
|
auto self = operands[1]->getValue();
|
||||||
|
|
|
@ -1040,6 +1040,94 @@ def _get_invocations_for_op_with_tensor_arg_followed_by(*args):
|
||||||
Invocation(ZeroDTensorWithDtype(torch.bool), *args),
|
Invocation(ZeroDTensorWithDtype(torch.bool), *args),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
def _get_invocations_for_fp_only_op_with_tensor_arg_followed_by(*args):
|
||||||
|
"""Generate invocations for floating point only op."""
|
||||||
|
return [
|
||||||
|
Invocation(NonZeroDTensorWithDtype(torch.float32), *args),
|
||||||
|
Invocation(NonZeroDTensorWithDtype(torch.float64), *args),
|
||||||
|
Invocation(NonZeroDTensorWithDtype(torch.bfloat16), *args),
|
||||||
|
ErrorInvocation(NonZeroDTensorWithDtype(torch.int64), *args),
|
||||||
|
ErrorInvocation(NonZeroDTensorWithDtype(torch.int32), *args),
|
||||||
|
ErrorInvocation(NonZeroDTensorWithDtype(torch.bool), *args),
|
||||||
|
Invocation(ZeroDTensorWithDtype(torch.float32), *args),
|
||||||
|
Invocation(ZeroDTensorWithDtype(torch.float64), *args),
|
||||||
|
Invocation(ZeroDTensorWithDtype(torch.bfloat16), *args),
|
||||||
|
ErrorInvocation(ZeroDTensorWithDtype(torch.int64), *args),
|
||||||
|
ErrorInvocation(ZeroDTensorWithDtype(torch.int32), *args),
|
||||||
|
ErrorInvocation(ZeroDTensorWithDtype(torch.bool), *args),
|
||||||
|
]
|
||||||
|
|
||||||
|
def _get_dtype_of_floating_point_op(input_dtype: int) -> int:
|
||||||
|
if input_dtype in (torch.float64, torch.bfloat16, torch.float16):
|
||||||
|
return input_dtype
|
||||||
|
return torch.float32
|
||||||
|
|
||||||
|
@check_dtype_function(_get_invocations_for_op_with_tensor_arg_followed_by())
|
||||||
|
def aten〇tanh〡dtype(self_rank: int, self_dtype: int) -> int:
|
||||||
|
return _get_dtype_of_floating_point_op(self_dtype)
|
||||||
|
|
||||||
|
@check_dtype_function(_get_invocations_for_op_with_tensor_arg_followed_by())
|
||||||
|
def aten〇exp〡dtype(self_rank: int, self_dtype: int) -> int:
|
||||||
|
return _get_dtype_of_floating_point_op(self_dtype)
|
||||||
|
|
||||||
|
@check_dtype_function(_get_invocations_for_op_with_tensor_arg_followed_by())
|
||||||
|
def aten〇expm1〡dtype(self_rank: int, self_dtype: int) -> int:
|
||||||
|
return _get_dtype_of_floating_point_op(self_dtype)
|
||||||
|
|
||||||
|
@check_dtype_function(_get_invocations_for_op_with_tensor_arg_followed_by())
|
||||||
|
def aten〇sin〡dtype(self_rank: int, self_dtype: int) -> int:
|
||||||
|
return _get_dtype_of_floating_point_op(self_dtype)
|
||||||
|
|
||||||
|
@check_dtype_function(_get_invocations_for_op_with_tensor_arg_followed_by())
|
||||||
|
def aten〇cos〡dtype(self_rank: int, self_dtype: int) -> int:
|
||||||
|
return _get_dtype_of_floating_point_op(self_dtype)
|
||||||
|
|
||||||
|
@check_dtype_function(_get_invocations_for_op_with_tensor_arg_followed_by())
|
||||||
|
def aten〇sigmoid〡dtype(self_rank: int, self_dtype: int) -> int:
|
||||||
|
return _get_dtype_of_floating_point_op(self_dtype)
|
||||||
|
|
||||||
|
@check_dtype_function(_get_invocations_for_op_with_tensor_arg_followed_by())
|
||||||
|
def aten〇reciprocal〡dtype(self_rank: int, self_dtype: int) -> int:
|
||||||
|
return _get_dtype_of_floating_point_op(self_dtype)
|
||||||
|
|
||||||
|
@check_dtype_function(_get_invocations_for_op_with_tensor_arg_followed_by())
|
||||||
|
def aten〇sqrt〡dtype(self_rank: int, self_dtype: int) -> int:
|
||||||
|
return _get_dtype_of_floating_point_op(self_dtype)
|
||||||
|
|
||||||
|
@check_dtype_function(_get_invocations_for_op_with_tensor_arg_followed_by())
|
||||||
|
def aten〇log〡dtype(self_rank: int, self_dtype: int) -> int:
|
||||||
|
return _get_dtype_of_floating_point_op(self_dtype)
|
||||||
|
|
||||||
|
@check_dtype_function(_get_invocations_for_op_with_tensor_arg_followed_by())
|
||||||
|
def aten〇log2〡dtype(self_rank: int, self_dtype: int) -> int:
|
||||||
|
return _get_dtype_of_floating_point_op(self_dtype)
|
||||||
|
|
||||||
|
@check_dtype_function(_get_invocations_for_op_with_tensor_arg_followed_by())
|
||||||
|
def aten〇log1p〡dtype(self_rank: int, self_dtype: int) -> int:
|
||||||
|
return _get_dtype_of_floating_point_op(self_dtype)
|
||||||
|
|
||||||
|
@check_dtype_function(_get_invocations_for_op_with_tensor_arg_followed_by())
|
||||||
|
def aten〇rsqrt〡dtype(self_rank: int, self_dtype: int) -> int:
|
||||||
|
return _get_dtype_of_floating_point_op(self_dtype)
|
||||||
|
|
||||||
|
@check_dtype_function(_get_invocations_for_op_with_tensor_arg_followed_by())
|
||||||
|
def aten〇erf〡dtype(self_rank: int, self_dtype: int) -> int:
|
||||||
|
return _get_dtype_of_floating_point_op(self_dtype)
|
||||||
|
|
||||||
|
@check_dtype_function(_get_invocations_for_fp_only_op_with_tensor_arg_followed_by())
|
||||||
|
def aten〇softplus〡dtype(self_rank: int, self_dtype: int, beta: Union[int, float] = 1, threshold: Union[int, float] = 20) -> int:
|
||||||
|
assert self_dtype not in (torch.int64, torch.int32, torch.bool)
|
||||||
|
return _get_dtype_of_floating_point_op(self_dtype)
|
||||||
|
|
||||||
|
@check_dtype_function(_get_invocations_for_fp_only_op_with_tensor_arg_followed_by([0]))
|
||||||
|
def aten〇frobenius_norm〇dim〡dtype(self_rank: int, self_dtype: int, dim: List[int], keepdim: bool = False) -> int:
|
||||||
|
assert self_dtype not in (torch.int64, torch.int32, torch.bool)
|
||||||
|
return _get_dtype_of_floating_point_op(self_dtype)
|
||||||
|
|
||||||
|
@check_dtype_function(_get_invocations_for_op_with_tensor_arg_followed_by())
|
||||||
|
def prims〇sqrt〡dtype(self_rank: int, self_dtype: int) -> int:
|
||||||
|
return _get_dtype_of_floating_point_op(self_dtype)
|
||||||
|
|
||||||
@check_dtype_function(_get_invocations_for_op_with_tensor_arg_followed_by())
|
@check_dtype_function(_get_invocations_for_op_with_tensor_arg_followed_by())
|
||||||
def aten〇all〡dtype(self_rank: int, self_dtype: int) -> int:
|
def aten〇all〡dtype(self_rank: int, self_dtype: int) -> int:
|
||||||
return torch.bool
|
return torch.bool
|
||||||
|
@ -1167,13 +1255,6 @@ def aten〇floor_divide〡dtype(self_rank: int, self_dtype: int, other_rank: int
|
||||||
def aten〇rsub〇Scalar〡dtype(self_rank: int, self_dtype: int, other: Union[int, float], alpha: Union[int, float] = 1) -> int:
|
def aten〇rsub〇Scalar〡dtype(self_rank: int, self_dtype: int, other: Union[int, float], alpha: Union[int, float] = 1) -> int:
|
||||||
return promote_dtypes([self_rank, None], [self_dtype, get_dtype_of_scalar(other)])
|
return promote_dtypes([self_rank, None], [self_dtype, get_dtype_of_scalar(other)])
|
||||||
|
|
||||||
@check_dtype_function(_get_invocations_for_op_with_tensor_arg_followed_by())
|
|
||||||
def aten〇expm1〡dtype(self_rank: int, self_dtype: int) -> int:
|
|
||||||
if self_dtype == torch.float64 or self_dtype == torch.bfloat16 or self_dtype == torch.float16:
|
|
||||||
return self_dtype
|
|
||||||
else:
|
|
||||||
return torch.float32
|
|
||||||
|
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
# Main
|
# Main
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
|
@ -117,37 +117,3 @@ func.func @prim.loop$region_arg_to_internal(%none: !torch.none) -> !torch.option
|
||||||
} : (!torch.int, !torch.bool, !torch.optional<tensor>) -> (!torch.optional<tensor>)
|
} : (!torch.int, !torch.bool, !torch.optional<tensor>) -> (!torch.optional<tensor>)
|
||||||
return %ret: !torch.optional<tensor>
|
return %ret: !torch.optional<tensor>
|
||||||
}
|
}
|
||||||
|
|
||||||
// -----
|
|
||||||
|
|
||||||
// CHECK-LABEL: func.func @f
|
|
||||||
// CHECK: %[[ATEN:.*]] = torch.aten.cos %{{.*}} : !torch.vtensor -> !torch.vtensor<*,f32>
|
|
||||||
// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[ATEN]] : !torch.vtensor<*,f32> to !torch.vtensor
|
|
||||||
// CHECK: return %[[CAST]] : !torch.vtensor
|
|
||||||
func.func @f(%arg0: !torch.vtensor<*,f32>) -> !torch.vtensor {
|
|
||||||
%cast = torch.tensor_static_info_cast %arg0 : !torch.vtensor<*,f32> to !torch.vtensor
|
|
||||||
cf.br ^bb1(%cast: !torch.vtensor)
|
|
||||||
^bb1(%arg1: !torch.vtensor):
|
|
||||||
%1 = torch.aten.cos %arg1 : !torch.vtensor -> !torch.vtensor
|
|
||||||
return %1 : !torch.vtensor
|
|
||||||
}
|
|
||||||
|
|
||||||
// -----
|
|
||||||
|
|
||||||
// CHECK-LABEL: func.func @f
|
|
||||||
// CHECK: func.func private @callee
|
|
||||||
// CHECK-NEXT: torch.aten.cos %{{.*}} : !torch.vtensor -> !torch.vtensor<*,f32>
|
|
||||||
func.func @f() {
|
|
||||||
builtin.module {
|
|
||||||
func.func private @callee(%arg0: !torch.vtensor) {
|
|
||||||
%1 = torch.aten.cos %arg0 : !torch.vtensor -> !torch.vtensor
|
|
||||||
return
|
|
||||||
}
|
|
||||||
func.func @caller(%arg0: !torch.vtensor<*,f32>) {
|
|
||||||
%cast = torch.tensor_static_info_cast %arg0 : !torch.vtensor<*,f32> to !torch.vtensor
|
|
||||||
call @callee(%cast) : (!torch.vtensor) -> ()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
|
@ -6,57 +6,6 @@
|
||||||
// Code for testing transfer functions for new ops (which is most changes)
|
// Code for testing transfer functions for new ops (which is most changes)
|
||||||
// should go in refine-types-ops.mlir.
|
// should go in refine-types-ops.mlir.
|
||||||
|
|
||||||
// -----
|
|
||||||
// CHECK-LABEL: func.func @basic(
|
|
||||||
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<*,f32>) -> !torch.vtensor {
|
|
||||||
// CHECK: %[[COS:.*]] = torch.aten.cos %[[ARG0]] : !torch.vtensor<*,f32> -> !torch.vtensor<*,f32>
|
|
||||||
// CHECK: %[[RESULT:.*]] = torch.tensor_static_info_cast %[[COS]] : !torch.vtensor<*,f32> to !torch.vtensor
|
|
||||||
// CHECK: return %[[RESULT]] : !torch.vtensor
|
|
||||||
func.func @basic(%arg0: !torch.vtensor<*,f32>) -> !torch.vtensor {
|
|
||||||
%1 = torch.aten.cos %arg0 : !torch.vtensor<*,f32> -> !torch.vtensor
|
|
||||||
return %1 : !torch.vtensor
|
|
||||||
}
|
|
||||||
|
|
||||||
// -----
|
|
||||||
// CHECK-LABEL: func.func @keep_existing_shape_information(
|
|
||||||
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<*,f32>) -> !torch.vtensor<[2],f32> {
|
|
||||||
// CHECK: %[[COS:.*]] = torch.aten.cos %[[ARG0]] : !torch.vtensor<*,f32> -> !torch.vtensor<[2],f32>
|
|
||||||
// CHECK: return %[[COS]] : !torch.vtensor<[2],f32>
|
|
||||||
func.func @keep_existing_shape_information(%arg0: !torch.vtensor<*,f32>) -> !torch.vtensor<[2],f32> {
|
|
||||||
%1 = torch.aten.cos %arg0 : !torch.vtensor<*,f32> -> !torch.vtensor<[2], f32>
|
|
||||||
return %1 : !torch.vtensor<[2],f32>
|
|
||||||
}
|
|
||||||
|
|
||||||
// -----
|
|
||||||
// CHECK-LABEL: func.func @propagate_through_multiple_ops(
|
|
||||||
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<*,f32>) -> !torch.vtensor {
|
|
||||||
// CHECK: %[[COS0:.*]] = torch.aten.cos %[[ARG0]] : !torch.vtensor<*,f32> -> !torch.vtensor<*,f32>
|
|
||||||
// CHECK: %[[COS1:.*]] = torch.aten.cos %[[COS0]] : !torch.vtensor<*,f32> -> !torch.vtensor<*,f32>
|
|
||||||
// CHECK: %[[COS2:.*]] = torch.aten.cos %[[COS1]] : !torch.vtensor<*,f32> -> !torch.vtensor<*,f32>
|
|
||||||
// CHECK: %[[COS3:.*]] = torch.tensor_static_info_cast %[[COS2]] : !torch.vtensor<*,f32> to !torch.vtensor
|
|
||||||
// CHECK: return %[[COS3]] : !torch.vtensor
|
|
||||||
func.func @propagate_through_multiple_ops(%arg0: !torch.vtensor<*,f32>) -> !torch.vtensor {
|
|
||||||
%1 = torch.aten.cos %arg0 : !torch.vtensor<*,f32> -> !torch.vtensor
|
|
||||||
%2 = torch.aten.cos %1 : !torch.vtensor -> !torch.vtensor
|
|
||||||
%3 = torch.aten.cos %2 : !torch.vtensor -> !torch.vtensor
|
|
||||||
return %3 : !torch.vtensor
|
|
||||||
}
|
|
||||||
|
|
||||||
// -----
|
|
||||||
// Check rewriting logic in case of mixes of users that do/don't allow type
|
|
||||||
// refinement.
|
|
||||||
// CHECK-LABEL: func.func @mixed_allowing_not_allowing_type_refinement(
|
|
||||||
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<*,f32>) -> (!torch.vtensor, !torch.vtensor) {
|
|
||||||
// CHECK: %[[COS0:.*]] = torch.aten.cos %[[ARG0]] : !torch.vtensor<*,f32> -> !torch.vtensor<*,f32>
|
|
||||||
// CHECK: %[[ERASED:.*]] = torch.tensor_static_info_cast %[[COS0]] : !torch.vtensor<*,f32> to !torch.vtensor
|
|
||||||
// CHECK: %[[COS1:.*]] = torch.aten.cos %[[COS0]] : !torch.vtensor<*,f32> -> !torch.vtensor<*,f32>
|
|
||||||
// CHECK: return %[[ERASED]], %[[ERASED]] : !torch.vtensor, !torch.vtensor
|
|
||||||
func.func @mixed_allowing_not_allowing_type_refinement(%arg0: !torch.vtensor<*,f32>) -> (!torch.vtensor, !torch.vtensor) {
|
|
||||||
%1 = torch.aten.cos %arg0 : !torch.vtensor<*,f32> -> !torch.vtensor
|
|
||||||
%3 = torch.aten.cos %1 : !torch.vtensor -> !torch.vtensor
|
|
||||||
return %1, %1 : !torch.vtensor, !torch.vtensor
|
|
||||||
}
|
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
// CHECK-LABEL: func.func @torch.overwrite.tensor.contents$dynamic_overwrites_static(
|
// CHECK-LABEL: func.func @torch.overwrite.tensor.contents$dynamic_overwrites_static(
|
||||||
|
|
Loading…
Reference in New Issue