Add miscellaneous dtype functions

This commit adds dtype functions for:

- AtenAtan2Op
- AtenLinearOp
- AtenMaxPool2dWithIndicesOp
- AtenCatOp
- Aten_ShapeAsTensorOp
- AtenScalarImplicitOp
- PrimNumToTensorScalarOp
Branch: dtype-functions-staging
Ramiro Leal-Cavazos 2023-05-03 18:12:42 +00:00
parent 8e987d92bf
commit 423f94a1ae
4 changed files with 116 additions and 189 deletions
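
The MLIR bodies in the generated library below call helpers such as promote_dtypes and is_integer_dtype from build_tools.library_generator, and the library string itself is generated from the Python dtype functions shown later in this diff, so the two files change together. Every dtype function follows one pattern: it receives a (rank, dtype) tuple per tensor operand and returns the dtype of the result. A minimal sketch of that pattern (the function name is illustrative, not one of the ops above):

from typing import Tuple
import torch

# Sketch only: a dtype function maps each tensor operand's (rank, dtype)
# pair to the op's result dtype.
def example_op〡dtype(self_rank_dtype: Tuple[int, int]) -> int:
    self_rank, self_dtype = self_rank_dtype
    return self_dtype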


@@ -8156,6 +8156,12 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" return %0#1 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.max_pool2d_with_indices\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>, %arg3: !torch.list<int>, %arg4: !torch.list<int>, %arg5: !torch.bool) -> !torch.tuple<int, int> {\n"
" %int4 = torch.constant.int 4\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" %1 = torch.prim.TupleConstruct %0#1, %int4 : !torch.int, !torch.int -> !torch.tuple<int, int>\n"
" return %1 : !torch.tuple<int, int>\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.mish\"(%arg0: !torch.tuple<int, int>) -> !torch.int {\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" return %0#1 : !torch.int\n"
@@ -10042,6 +10048,64 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" }\n"
" return %4 : !torch.tuple<int, int>\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.atan2\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>) -> !torch.int {\n"
" %int6 = torch.constant.int 6\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" %1:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" %2 = torch.prim.ListConstruct %0#0, %1#0 : (!torch.int, !torch.int) -> !torch.list<optional<int>>\n"
" %3 = torch.prim.ListConstruct %0#1, %1#1 : (!torch.int, !torch.int) -> !torch.list<int>\n"
" %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n"
" %5 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_integer_dtype(%4) : (!torch.int) -> !torch.bool\n"
" %6 = torch.prim.If %5 -> (!torch.int) {\n"
" torch.prim.If.yield %int6 : !torch.int\n"
" } else {\n"
" torch.prim.If.yield %4 : !torch.int\n"
" }\n"
" return %6 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.linear\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.optional<tuple<int, int>>) -> !torch.int {\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" return %0#1 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.cat\"(%arg0: !torch.list<tuple<int, int>>, %arg1: !torch.int) -> !torch.int {\n"
" %true = torch.constant.bool true\n"
" %none = torch.constant.none\n"
" %str = torch.constant.str \"AssertionError: \"\n"
" %int0 = torch.constant.int 0\n"
" %0 = torch.prim.ListConstruct : () -> !torch.list<optional<int>>\n"
" %1 = torch.prim.ListConstruct : () -> !torch.list<int>\n"
" %2 = torch.aten.len.t %arg0 : !torch.list<tuple<int, int>> -> !torch.int\n"
" %3 = torch.aten.ne.int %2, %int0 : !torch.int, !torch.int -> !torch.bool\n"
" torch.prim.If %3 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" %4 = torch.aten.len.t %arg0 : !torch.list<tuple<int, int>> -> !torch.int\n"
" torch.prim.Loop %4, %true, init() {\n"
" ^bb0(%arg2: !torch.int):\n"
" %6 = torch.aten.__getitem__.t %arg0, %arg2 : !torch.list<tuple<int, int>>, !torch.int -> !torch.tuple<int, int>\n"
" %7:2 = torch.prim.TupleUnpack %6 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" %8 = torch.aten.append.t %0, %7#0 : !torch.list<optional<int>>, !torch.int -> !torch.list<optional<int>>\n"
" %9 = torch.aten.append.t %1, %7#1 : !torch.list<int>, !torch.int -> !torch.list<int>\n"
" torch.prim.Loop.condition %true, iter()\n"
" } : (!torch.int, !torch.bool) -> ()\n"
" %5 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%0, %1) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n"
" return %5 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten._shape_as_tensor\"(%arg0: !torch.tuple<int, int>) -> !torch.int {\n"
" %int4 = torch.constant.int 4\n"
" return %int4 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.ScalarImplicit\"(%arg0: !torch.tuple<int, int>) -> !torch.int {\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" return %0#1 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.prim.NumToTensor.Scalar\"(%arg0: !torch.union<float, int>) -> !torch.int {\n"
" %0 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg0) : (!torch.union<float, int>) -> !torch.int\n"
" return %0 : !torch.int\n"
" }\n"
"}\n"
"";
// clang-format on


@@ -520,21 +520,6 @@ static Type getPromotedResultScalarType(ArrayRef<Type> scalarTypes) {
  return *result;
}

static SmallVector<std::optional<bool>>
getRankIsNonZeroArray(ValueRange values) {
  SmallVector<std::optional<bool>> rankIsNonZero;
  for (Value v : values) {
    if (auto tensorType = v.getType().dyn_cast<BaseTensorType>()) {
      if (tensorType.hasSizes()) {
        rankIsNonZero.push_back(tensorType.getSizes().size() != 0);
      } else {
        rankIsNonZero.push_back(std::nullopt);
      }
    }
  }
  return rankIsNonZero;
}

// Normally, tensor dimensions need to be known at compile time to do type
// promotion. `skipRankCheck`, when equal to true, can be used to indicate
// special cases that tensor operands are guaranteed to be not zero dimension
@@ -597,64 +582,6 @@ void TypeAnalysis::visitOperation(Operation *op,
    return;
  }

  // Dtype is always float32, except for bfloat16, float64 and nullptr after
  // promotion and assuming possible-zero rank.
  if (isa<AtenAtan2Op>(op)) {
    ValueKnowledge knowledge =
        ValueKnowledge::getTensorPessimisticValueState(op->getContext());
    Type promotedDtype = getPromotedResultType(
        op->getContext(), {&operands[0]->getValue(), &operands[1]->getValue()},
        getRankIsNonZeroArray(op->getOperands()));
    if (promotedDtype) {
      knowledge.dtype = Float32Type::get(op->getContext());
      if (promotedDtype.isa<BFloat16Type, Float64Type>())
        knowledge.dtype = promotedDtype;
    }
    incorporateKnowledge(op->getResult(0), knowledge);
    return;
  }

  if (auto linear = llvm::dyn_cast<AtenLinearOp>(op)) {
    visitAtenLinearOp(linear, operands);
    return;
  }

  if (isa<AtenMaxPool2dWithIndicesOp>(op)) {
    auto self = operands[0]->getValue();
    auto result0Knowledge =
        ValueKnowledge::getTensorPessimisticValueState(op->getContext());
    result0Knowledge.dtype = self.dtype;
    auto result1Knowledge =
        ValueKnowledge::getTensorPessimisticValueState(op->getContext());
    result1Knowledge.dtype =
        IntegerType::get(op->getContext(), 64, IntegerType::Signed);
    incorporateKnowledge(op->getResult(0), result0Knowledge);
    incorporateKnowledge(op->getResult(1), result1Knowledge);
    return;
  }

  if (auto tensor = dyn_cast<AtenTensorOp>(op)) {
    visitAtenTensorOp(tensor);
    return;
  }

  if (auto cat = dyn_cast<AtenCatOp>(op)) {
    visitAtenCatLikeOp<AtenCatOp>(cat, operands);
    return;
  } else if (auto stack = dyn_cast<AtenStackOp>(op)) {
    visitAtenCatLikeOp<AtenStackOp>(stack, operands);
    return;
  }

  if (auto shapeAsTensor = dyn_cast<Aten_ShapeAsTensorOp>(op)) {
    auto knowledge =
        ValueKnowledge::getTensorPessimisticValueState(op->getContext());
    knowledge.dtype =
        IntegerType::get(op->getContext(), 64, IntegerType::Signed);
    incorporateKnowledge(shapeAsTensor.getResult(), knowledge);
    return;
  }

  if (auto embedding = dyn_cast<AtenEmbeddingOp>(op)) {
    auto knowledge =
        ValueKnowledge::getTensorPessimisticValueState(op->getContext());
@@ -683,11 +610,6 @@ void TypeAnalysis::visitOperation(Operation *op,
    return;
  }

  if (auto numToTensorOp = dyn_cast<PrimNumToTensorScalarOp>(op)) {
    visitNumToTensorOp(numToTensorOp);
    return;
  }

  if (isa<AtenAddIntOp, AtenSubIntOp, AtenMulIntOp, AtenAddOp>(op)) {
    visitBinaryScalarOp(op, operands);
    return;


@@ -1546,6 +1546,11 @@ def aten〇max_pool2d〡dtype(self_rank_dtype: Tuple[int, int], kernel_size: Lis
    self_rank, self_dtype = self_rank_dtype
    return self_dtype

@check_dtype_function(_check_tensors_with_the_same_dtype(tensor_shapes=[(2, 3, 5, 7)], kernel_size=[2, 2]))
def aten〇max_pool2d_with_indices〡dtype(self_rank_dtype: Tuple[int, int], kernel_size: List[int], stride: List[int] = (), padding: List[int] = (0, 0), dilation: List[int] = (1, 1), ceil_mode: bool = False) -> Tuple[int, int]:
    self_rank, self_dtype = self_rank_dtype
    return self_dtype, torch.int64
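
A quick sanity check of the rule above against eager PyTorch (illustrative, not part of this commit): the pooled values keep the input dtype and the indices are int64.

import torch
import torch.nn.functional as F

# Matches the (self_dtype, torch.int64) result pair above.
out, idx = F.max_pool2d(torch.randn(2, 3, 5, 7), kernel_size=2, return_indices=True)
assert out.dtype == torch.float32 and idx.dtype == torch.int64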

@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1))
def aten〇mish〡dtype(self_rank_dtype: Tuple[int, int]) -> int:
    self_rank, self_dtype = self_rank_dtype
@@ -2976,6 +2981,53 @@ def aten〇var_mean〡dtype(self_rank_dtype: Tuple[int, int], unbiased: bool = T
        return torch.float64, self_dtype
    return self_dtype, self_dtype

@check_dtype_function(_check_two_tensor_op())
def aten〇atan2〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int]) -> int:
    self_rank, self_dtype = self_rank_dtype
    other_rank, other_dtype = other_rank_dtype
    ranks: List[Optional[int]] = [self_rank, other_rank]
    dtypes = [self_dtype, other_dtype]
    promoted_dtype = promote_dtypes(ranks, dtypes)
    if is_integer_dtype(promoted_dtype):
        return torch.float32
    return promoted_dtype
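
Eager-mode illustration (not part of the commit): atan2 on integer operands promotes to float32, and otherwise keeps the promoted dtype.

import torch

assert torch.atan2(torch.ones(3, dtype=torch.int32),
                   torch.ones(3, dtype=torch.int64)).dtype == torch.float32
assert torch.atan2(torch.ones(3, dtype=torch.float64),
                   torch.ones(3)).dtype == torch.float64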

@check_dtype_function(_check_two_tensor_op())
def aten〇linear〡dtype(input_rank_dtype: Tuple[int, int], weight_rank_dtype: Tuple[int, int], bias_rank_dtype: Optional[Tuple[int, int]] = None) -> int:
    input_rank, input_dtype = input_rank_dtype
    weight_rank, weight_dtype = weight_rank_dtype
    return input_dtype
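
For reference (illustrative, not in the commit): eager PyTorch's linear simply follows the input dtype.

import torch
import torch.nn.functional as F

x, w, b = torch.randn(4, 3), torch.randn(5, 3), torch.randn(5)
assert F.linear(x, w, b).dtype == x.dtype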

@check_dtype_function(
    [Invocation([NonZeroDTensorWithDtype(torch.float32), NonZeroDTensorWithDtype(torch.int32)]),
     Invocation([NonZeroDTensorWithDtype(torch.float16), NonZeroDTensorWithDtype(torch.float64)]),
     Invocation([NonZeroDTensorWithDtype(torch.float32), NonZeroDTensorWithDtype(torch.int32),
                 NonZeroDTensorWithDtype(torch.complex64)])])
def aten〇cat〡dtype(tensors_rank_dtype: List[Tuple[int, int]], dim: int = 0) -> int:
    ranks: List[Optional[int]] = []
    dtypes: List[int] = []
    assert len(tensors_rank_dtype) != 0
    for tensor_rank_dtype in tensors_rank_dtype:
        tensor_rank, tensor_dtype = tensor_rank_dtype
        ranks.append(tensor_rank)
        dtypes.append(tensor_dtype)
    return promote_dtypes(ranks, dtypes)
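
Eager-mode illustration (not in the commit): cat promotes across the dtypes of all operands, matching the mixed-dtype Invocations above.

import torch

t = torch.cat([torch.zeros(2, dtype=torch.float32),
               torch.zeros(2, dtype=torch.int32)])
assert t.dtype == torch.float32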

@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1))
def aten〇_shape_as_tensor〡dtype(self_rank_dtype: Tuple[int, int]) -> int:
    return torch.int64
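
A quick check against eager PyTorch (illustrative, not in the commit; _shape_as_tensor is a private op): the shape tensor is always int64 regardless of the input dtype.

import torch

assert torch._shape_as_tensor(torch.randn(2, 3)).dtype == torch.int64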

# Does not work on meta backend
# @check_dtype_function(_check_tensors_with_the_same_dtype(tensor_shapes=[()]))
def aten〇ScalarImplicit〡dtype(a_rank_dtype: Tuple[int, int]) -> int:
    a_rank, a_dtype = a_rank_dtype
    return a_dtype

@check_dtype_function([Invocation(0), Invocation(0.0)])
def prim〇NumToTensor〇Scalar〡dtype(a: Union[int, float]) -> int:
    return get_dtype_of_scalar(a)
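
Eager-mode illustration (not in the commit) of what get_dtype_of_scalar encodes: an int scalar becomes an int64 tensor, while a float scalar gets the default float32 dtype.

import torch

assert torch.tensor(0).dtype == torch.int64
assert torch.tensor(0.0).dtype == torch.float32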
# ==============================================================================
# Main
# ==============================================================================


@@ -3,73 +3,6 @@
// This file is for tests for individual ops that require a new transfer
// function (i.e. new code called from visitOperation).
// -----
// CHECK-LABEL: func.func @torch.aten.linear(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,3],f32>,
// CHECK-SAME: %[[ARG1:.*]]: !torch.vtensor<[5,3],f32>,
// CHECK-SAME: %[[ARG2:.*]]: !torch.vtensor<[5],f32>) -> !torch.vtensor {
// CHECK: %[[LINEAR:.*]] = torch.aten.linear %[[ARG0]], %[[ARG1]], %[[ARG2]] : !torch.vtensor<[?,3],f32>, !torch.vtensor<[5,3],f32>, !torch.vtensor<[5],f32> -> !torch.vtensor<*,f32>
// CHECK: %[[RESULT:.*]] = torch.tensor_static_info_cast %[[LINEAR]] : !torch.vtensor<*,f32> to !torch.vtensor
// CHECK: return %[[RESULT]] : !torch.vtensor
func.func @torch.aten.linear(%arg0: !torch.vtensor<[?,3],f32>, %arg1: !torch.vtensor<[5,3],f32>, %arg2: !torch.vtensor<[5],f32>) -> !torch.vtensor {
  %1 = torch.aten.linear %arg0, %arg1, %arg2 : !torch.vtensor<[?,3],f32>, !torch.vtensor<[5,3],f32>, !torch.vtensor<[5],f32> -> !torch.vtensor
  return %1 : !torch.vtensor
}
// -----
// CHECK-LABEL: func.func @torch.aten.cat(
// CHECK-SAME: %[[T1:.*]]: !torch.tensor<[?,1,4],f32>,
// CHECK-SAME: %[[T2:.*]]: !torch.tensor<[2,3,4],f32>) -> !torch.tensor {
// CHECK: %[[INT1:.*]] = torch.constant.int 1
// CHECK: %[[TENSORS:.*]] = torch.prim.ListConstruct %[[T1]], %[[T2]] : (!torch.tensor<[?,1,4],f32>, !torch.tensor<[2,3,4],f32>) -> !torch.list<tensor>
// CHECK: %[[RET:.*]] = torch.aten.cat %[[TENSORS]], %[[INT1]] : !torch.list<tensor>, !torch.int -> !torch.tensor<*,f32>
// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[RET]] : !torch.tensor<*,f32> to !torch.tensor
// CHECK: return %[[CAST]] : !torch.tensor
func.func @torch.aten.cat(%t0: !torch.tensor<[?,1,4], f32>, %t1: !torch.tensor<[2,3,4], f32>) -> !torch.tensor {
  %int1 = torch.constant.int 1
  %tensorList = torch.prim.ListConstruct %t0, %t1 : (!torch.tensor<[?,1,4], f32>, !torch.tensor<[2,3,4], f32>) -> !torch.list<tensor>
  %ret = torch.aten.cat %tensorList, %int1 : !torch.list<tensor>, !torch.int -> !torch.tensor
  return %ret : !torch.tensor
}
// -----
// CHECK-LABEL: func.func @torch.aten.cat$promote_type(
// CHECK-SAME: %[[T1:.*]]: !torch.tensor<[2,1,4],i1>,
// CHECK-SAME: %[[T2:.*]]: !torch.tensor<[2,3,4],si64>) -> !torch.tensor {
// CHECK: %[[INT1:.*]] = torch.constant.int 1
// CHECK: %[[TENSORS:.*]] = torch.prim.ListConstruct %[[T1]], %[[T2]] : (!torch.tensor<[2,1,4],i1>, !torch.tensor<[2,3,4],si64>) -> !torch.list<tensor>
// CHECK: %[[RET:.*]] = torch.aten.cat %[[TENSORS]], %[[INT1]] : !torch.list<tensor>, !torch.int -> !torch.tensor<*,si64>
// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[RET]] : !torch.tensor<*,si64> to !torch.tensor
// CHECK: return %[[CAST]] : !torch.tensor
func.func @torch.aten.cat$promote_type(%t0: !torch.tensor<[2,1,4], i1>, %t1: !torch.tensor<[2,3,4], si64>) -> !torch.tensor {
  %int1 = torch.constant.int 1
  %tensorList = torch.prim.ListConstruct %t0, %t1 : (!torch.tensor<[2,1,4], i1>, !torch.tensor<[2,3,4], si64>) -> !torch.list<tensor>
  %ret = torch.aten.cat %tensorList, %int1 : !torch.list<tensor>, !torch.int -> !torch.tensor
  return %ret : !torch.tensor
}
// -----
// CHECK-LABEL: func.func @torch.aten._shape_as_tensor(
// CHECK-SAME: %[[INPUT:.*]]: !torch.tensor<[?,1,4],f32>) -> !torch.tensor {
// CHECK: %[[RET:.*]] = torch.aten._shape_as_tensor %[[INPUT]] : !torch.tensor<[?,1,4],f32> -> !torch.tensor<*,si64>
// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[RET]] : !torch.tensor<*,si64> to !torch.tensor
// CHECK: return %[[CAST]] : !torch.tensor
func.func @torch.aten._shape_as_tensor(%input: !torch.tensor<[?,1,4], f32>) -> !torch.tensor {
  %ret = torch.aten._shape_as_tensor %input : !torch.tensor<[?,1,4], f32> -> !torch.tensor
  return %ret : !torch.tensor
}
// -----
// CHECK-LABEL: func.func @torch.aten._shape_as_tensor$unknown_input_shape(
// CHECK-SAME: %[[INPUT:.*]]: !torch.tensor) -> !torch.tensor {
// CHECK: %[[RET:.*]] = torch.aten._shape_as_tensor %[[INPUT]] : !torch.tensor -> !torch.tensor<*,si64>
// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[RET]] : !torch.tensor<*,si64> to !torch.tensor
// CHECK: return %[[CAST]] : !torch.tensor
func.func @torch.aten._shape_as_tensor$unknown_input_shape(%input: !torch.tensor) -> !torch.tensor {
  %ret = torch.aten._shape_as_tensor %input : !torch.tensor -> !torch.tensor
  return %ret : !torch.tensor
}
// -----
// CHECK-LABEL: func.func @torch.aten.embedding(
// CHECK-SAME: %[[INPUT:.*]]: !torch.tensor<[104,512],f32>,
@@ -114,47 +47,3 @@ func.func @torch.aten.softmax.int$specified_dtype(%t: !torch.tensor<[2,3],f32>,
  return %ret : !torch.tensor
}
// -----
// CHECK-LABEL: func.func @torch.prim.NumToTensor.Scalar(
// CHECK-SAME: %[[SELF:.*]]: !torch.int) -> !torch.tensor {
// CHECK: %[[NTT:.*]] = torch.prim.NumToTensor.Scalar %[[SELF]] : !torch.int -> !torch.tensor<*,si64>
// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[NTT]] : !torch.tensor<*,si64> to !torch.tensor
// CHECK: return %[[CAST]] : !torch.tensor
func.func @torch.prim.NumToTensor.Scalar(%arg0: !torch.int) -> !torch.tensor {
  %0 = torch.prim.NumToTensor.Scalar %arg0 : !torch.int -> !torch.tensor
  return %0 : !torch.tensor
}
// -----
// CHECK-LABEL: func.func @torch.aten.tensor(
// CHECK-SAME: %[[DATA:.*]]: !torch.list<list<float>>) -> !torch.tensor {
// CHECK: %[[NONE:.*]] = torch.constant.none
// CHECK: %[[FALSE:.*]] = torch.constant.bool false
// CHECK: %[[RET:.*]] = torch.aten.tensor %[[DATA]], %[[NONE]], %[[NONE]], %[[FALSE]]
// CHECK-SAME: : !torch.list<list<float>>, !torch.none, !torch.none, !torch.bool
// CHECK-SAME: -> !torch.tensor<*,f32>
// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[RET]] : !torch.tensor<*,f32> to !torch.tensor
// CHECK: return %[[CAST]] : !torch.tensor
func.func @torch.aten.tensor(%t: !torch.list<list<float>>) -> !torch.tensor {
  %none = torch.constant.none
  %false = torch.constant.bool false
  %ret = torch.aten.tensor %t, %none, %none, %false : !torch.list<list<float>>, !torch.none, !torch.none, !torch.bool -> !torch.tensor
  return %ret : !torch.tensor
}
// -----
// CHECK-LABEL: func.func @torch.aten.tensor$specified_dtype(
// CHECK-SAME: %[[DATA:.*]]: !torch.list<list<float>>) -> !torch.tensor {
// CHECK: %[[NONE:.*]] = torch.constant.none
// CHECK: %[[INT4:.*]] = torch.constant.int 4
// CHECK: %[[FALSE:.*]] = torch.constant.bool false
// CHECK: %[[RET:.*]] = torch.aten.tensor %[[DATA]], %[[INT4]], %[[NONE]], %[[FALSE]] : !torch.list<list<float>>, !torch.int, !torch.none, !torch.bool -> !torch.tensor<*,si64>
// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[RET]] : !torch.tensor<*,si64> to !torch.tensor
// CHECK: return %[[CAST]] : !torch.tensor
func.func @torch.aten.tensor$specified_dtype(%t: !torch.list<list<float>>) -> !torch.tensor {
  %none = torch.constant.none
  %int4 = torch.constant.int 4
  %false = torch.constant.bool false
  %ret = torch.aten.tensor %t, %int4, %none, %false : !torch.list<list<float>>, !torch.int, !torch.none, !torch.bool -> !torch.tensor
  return %ret : !torch.tensor
}