mirror of https://github.com/llvm/torch-mlir
Add dtype functions for ops that take dtype from 2nd operand (#1891)
parent 63945a2fd4
commit ce7abf4911
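This change moves dtype propagation for aten.nll_loss_backward and aten.max_pool2d_with_indices_backward out of the C++ TypeAnalysis pass and into the Python-based abstract interpretation library: both ops take their result dtype from the second operand (self). As a minimal sketch of that pattern, assuming the hypothetical name take_dtype_from_second_operand (not part of the library); the actual dtype functions added below additionally assert that grad_output and self have the same dtype and reject complex, integer, and float16 inputs:

from typing import Tuple

# Hypothetical helper illustrating the "dtype from 2nd operand" pattern.
# Each argument is a (rank, dtype) pair, as in the library's dtype functions.
def take_dtype_from_second_operand(grad_output_rank_dtype: Tuple[int, int],
                                   self_rank_dtype: Tuple[int, int]) -> int:
    _, self_dtype = self_rank_dtype
    # The result dtype is simply the dtype of self, the second operand.
    return self_dtype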
@@ -7687,6 +7687,88 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    %2 = call @__torch__._get_dtype_of_floating_point_op(%0#1) : (!torch.int) -> !torch.int\n"
 "    return %2 : !torch.int\n"
 "  }\n"
+"  func.func @\"__torch_mlir_dtype_fn.aten.nll_loss_backward\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.tuple<int, int>, %arg3: !torch.optional<tuple<int, int>>, %arg4: !torch.int, %arg5: !torch.int, %arg6: !torch.tuple<int, int>) -> !torch.int {\n"
+"    %str = torch.constant.str \"AssertionError: `self` cannot have float16 dtype\"\n"
+"    %int5 = torch.constant.int 5\n"
+"    %str_0 = torch.constant.str \"AssertionError: `self` cannot have integer dtype\"\n"
+"    %str_1 = torch.constant.str \"AssertionError: `self` cannot have complex dtype\"\n"
+"    %none = torch.constant.none\n"
+"    %str_2 = torch.constant.str \"AssertionError: `grad_output` and `self` must have the same dtype\"\n"
+"    %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
+"    %1:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
+"    %2 = torch.aten.eq.int %0#1, %1#1 : !torch.int, !torch.int -> !torch.bool\n"
+"    torch.prim.If %2 -> () {\n"
+"      torch.prim.If.yield\n"
+"    } else {\n"
+"      torch.prim.RaiseException %str_2, %none : !torch.str, !torch.none\n"
+"      torch.prim.If.yield\n"
+"    }\n"
+"    %3 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%1#1) : (!torch.int) -> !torch.bool\n"
+"    %4 = torch.aten.__not__ %3 : !torch.bool -> !torch.bool\n"
+"    torch.prim.If %4 -> () {\n"
+"      torch.prim.If.yield\n"
+"    } else {\n"
+"      torch.prim.RaiseException %str_1, %none : !torch.str, !torch.none\n"
+"      torch.prim.If.yield\n"
+"    }\n"
+"    %5 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_integer_dtype(%1#1) : (!torch.int) -> !torch.bool\n"
+"    %6 = torch.aten.__not__ %5 : !torch.bool -> !torch.bool\n"
+"    torch.prim.If %6 -> () {\n"
+"      torch.prim.If.yield\n"
+"    } else {\n"
+"      torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none\n"
+"      torch.prim.If.yield\n"
+"    }\n"
+"    %7 = torch.aten.ne.int %1#1, %int5 : !torch.int, !torch.int -> !torch.bool\n"
+"    torch.prim.If %7 -> () {\n"
+"      torch.prim.If.yield\n"
+"    } else {\n"
+"      torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
+"      torch.prim.If.yield\n"
+"    }\n"
+"    return %1#1 : !torch.int\n"
+"  }\n"
+"  func.func @\"__torch_mlir_dtype_fn.aten.max_pool2d_with_indices_backward\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.list<int>, %arg3: !torch.list<int>, %arg4: !torch.list<int>, %arg5: !torch.list<int>, %arg6: !torch.bool, %arg7: !torch.tuple<int, int>) -> !torch.int {\n"
+"    %str = torch.constant.str \"AssertionError: `self` cannot have float16 dtype\"\n"
+"    %int5 = torch.constant.int 5\n"
+"    %str_0 = torch.constant.str \"AssertionError: `self` cannot have integer dtype\"\n"
+"    %str_1 = torch.constant.str \"AssertionError: `self` cannot have complex dtype\"\n"
+"    %none = torch.constant.none\n"
+"    %str_2 = torch.constant.str \"AssertionError: `grad_output` and `self` must have the same dtype\"\n"
+"    %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
+"    %1:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
+"    %2 = torch.aten.eq.int %0#1, %1#1 : !torch.int, !torch.int -> !torch.bool\n"
+"    torch.prim.If %2 -> () {\n"
+"      torch.prim.If.yield\n"
+"    } else {\n"
+"      torch.prim.RaiseException %str_2, %none : !torch.str, !torch.none\n"
+"      torch.prim.If.yield\n"
+"    }\n"
+"    %3 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%1#1) : (!torch.int) -> !torch.bool\n"
+"    %4 = torch.aten.__not__ %3 : !torch.bool -> !torch.bool\n"
+"    torch.prim.If %4 -> () {\n"
+"      torch.prim.If.yield\n"
+"    } else {\n"
+"      torch.prim.RaiseException %str_1, %none : !torch.str, !torch.none\n"
+"      torch.prim.If.yield\n"
+"    }\n"
+"    %5 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_integer_dtype(%1#1) : (!torch.int) -> !torch.bool\n"
+"    %6 = torch.aten.__not__ %5 : !torch.bool -> !torch.bool\n"
+"    torch.prim.If %6 -> () {\n"
+"      torch.prim.If.yield\n"
+"    } else {\n"
+"      torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none\n"
+"      torch.prim.If.yield\n"
+"    }\n"
+"    %7 = torch.aten.ne.int %1#1, %int5 : !torch.int, !torch.int -> !torch.bool\n"
+"    torch.prim.If %7 -> () {\n"
+"      torch.prim.If.yield\n"
+"    } else {\n"
+"      torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
+"      torch.prim.If.yield\n"
+"    }\n"
+"    return %1#1 : !torch.int\n"
+"  }\n"
 "  func.func @\"__torch_mlir_dtype_fn.aten.all\"(%arg0: !torch.tuple<int, int>) -> !torch.int {\n"
 "    %int11 = torch.constant.int 11\n"
 "    %int0 = torch.constant.int 0\n"
@@ -672,16 +672,6 @@ void TypeAnalysis::visitOperation(Operation *op,
     return incorporateKnowledge(op->getResult(0), operands[0]->getValue());
   }
 
-  // Take dtype from second operand.
-  if (isa<AtenNllLossBackwardOp, AtenMaxPool2dWithIndicesBackwardOp>(op)) {
-    auto self = operands[1]->getValue();
-    auto knowledge =
-        ValueKnowledge::getTensorPessimisticValueState(op->getContext());
-    knowledge.dtype = self.dtype;
-    incorporateKnowledge(op->getResult(0), knowledge);
-    return;
-  }
-
   // Dtype is always si64.
   if (isa<AtenBincountOp>(op)) {
     auto knowledge =
@@ -1238,6 +1238,36 @@ def prims〇sqrt〡dtype(self_rank_dtype: Tuple[int, int]) -> int:
     assert self_dtype != torch.float16
     return _get_dtype_of_floating_point_op(self_dtype)
 
+@check_dtype_function(_check_tensors_with_the_same_dtype(
+    None, [(3,), (3, 4)],
+    {torch.complex128, torch.complex64, torch.float16, torch.int64, torch.int32, torch.int16, torch.int8, torch.uint8, torch.bool},
+    TensorOfShape(3, dtype=torch.int64), None, 0, 10, TensorOfShape(1, dtype=torch.float32)) +
+    [ErrorInvocation(TensorOfShape(3, dtype=torch.float32), TensorOfShape(3, 4, dtype=torch.float64), TensorOfShape(3, dtype=torch.int64), None, 0, 10, TensorOfShape(1, dtype=torch.float32)),
+     ErrorInvocation(TensorOfShape(3, dtype=torch.float64), TensorOfShape(3, 4, dtype=torch.float32), TensorOfShape(3, dtype=torch.int64), None, 0, 10, TensorOfShape(1, dtype=torch.float32))])
+def aten〇nll_loss_backward〡dtype(grad_output_rank_dtype: Tuple[int, int], self_rank_dtype: Tuple[int, int], target_rank_dtype: Tuple[int, int], weight_rank_dtype: Optional[Tuple[int, int]], reduction: int, ignore_index: int, total_weight_rank_dtype: Tuple[int, int]) -> int:
+    grad_output_rank, grad_output_dtype = grad_output_rank_dtype
+    self_rank, self_dtype = self_rank_dtype
+    assert grad_output_dtype == self_dtype, "`grad_output` and `self` must have the same dtype"
+    assert not is_complex_dtype(self_dtype), "`self` cannot have complex dtype"
+    assert not is_integer_dtype(self_dtype), "`self` cannot have integer dtype"
+    assert self_dtype != torch.float16, "`self` cannot have float16 dtype"
+    return self_dtype
+
+@check_dtype_function(_check_tensors_with_the_same_dtype(
+    None, [(2, 4, 7, 6), (2, 4, 6, 5)],
+    {torch.complex128, torch.complex64, torch.float16, torch.int64, torch.int32, torch.int16, torch.int8, torch.uint8, torch.bool},
+    [2, 2], [1, 1], [1, 1], [1, 1], False, TensorOfShape(2, 4, 7, 6, dtype=torch.int64)) +
+    [ErrorInvocation(TensorOfShape(2, 4, 7, 6, dtype=torch.float32), TensorOfShape(2, 4, 6, 5, dtype=torch.float64), [2, 2], [1, 1], [1, 1], [1, 1], False, TensorOfShape(2, 4, 7, 6, dtype=torch.int64)),
+     ErrorInvocation(TensorOfShape(2, 4, 7, 6, dtype=torch.float64), TensorOfShape(2, 4, 6, 5, dtype=torch.float32), [2, 2], [1, 1], [1, 1], [1, 1], False, TensorOfShape(2, 4, 7, 6, dtype=torch.int64))])
+def aten〇max_pool2d_with_indices_backward〡dtype(grad_output_rank_dtype: Tuple[int, int], self_rank_dtype: Tuple[int, int], kernel_size: List[int], stride: List[int], padding: List[int], dilation: List[int], ceil_mode: bool, indices_rank_dtype: Tuple[int, int]) -> int:
+    grad_output_rank, grad_output_dtype = grad_output_rank_dtype
+    self_rank, self_dtype = self_rank_dtype
+    assert grad_output_dtype == self_dtype, "`grad_output` and `self` must have the same dtype"
+    assert not is_complex_dtype(self_dtype), "`self` cannot have complex dtype"
+    assert not is_integer_dtype(self_dtype), "`self` cannot have integer dtype"
+    assert self_dtype != torch.float16, "`self` cannot have float16 dtype"
+    return self_dtype
+
 @check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1))
 def aten〇all〡dtype(self_rank_dtype: Tuple[int, int]) -> int:
     self_rank, self_dtype = self_rank_dtype
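For reference, a hedged usage sketch of the new dtype function: it can be called directly with (rank, dtype) tuples, mirroring the arguments the test invocations above pass (this assumes aten〇nll_loss_backward〡dtype from the abstract interp library and torch are in scope):

import torch

# grad_output and self share a dtype, target is int64, weight is omitted,
# reduction=0, ignore_index=10, total_weight is float32.
result_dtype = aten〇nll_loss_backward〡dtype(
    grad_output_rank_dtype=(1, torch.float32),
    self_rank_dtype=(2, torch.float32),
    target_rank_dtype=(1, torch.int64),
    weight_rank_dtype=None,
    reduction=0,
    ignore_index=10,
    total_weight_rank_dtype=(1, torch.float32),
)
assert result_dtype == torch.float32  # dtype taken from the second operand (self)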