mirror of https://github.com/llvm/torch-mlir
Add miscellaneous dtype functions
This commit adds dtype functions for: - AtenAtan2Op - AtenLinearOp - AtenMaxPool2dWithIndicesOp - AtenCatOp - Aten_ShapeAsTensorOp - AtenScalarImplicitOp - PrimNumToTensorScalarOpdtype-functions-staging
parent
8e987d92bf
commit
423f94a1ae
|
@ -8156,6 +8156,12 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
|
|||
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
|
||||
" return %0#1 : !torch.int\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_dtype_fn.aten.max_pool2d_with_indices\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>, %arg3: !torch.list<int>, %arg4: !torch.list<int>, %arg5: !torch.bool) -> !torch.tuple<int, int> {\n"
|
||||
" %int4 = torch.constant.int 4\n"
|
||||
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
|
||||
" %1 = torch.prim.TupleConstruct %0#1, %int4 : !torch.int, !torch.int -> !torch.tuple<int, int>\n"
|
||||
" return %1 : !torch.tuple<int, int>\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_dtype_fn.aten.mish\"(%arg0: !torch.tuple<int, int>) -> !torch.int {\n"
|
||||
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
|
||||
" return %0#1 : !torch.int\n"
|
||||
|
@ -10042,6 +10048,64 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
|
|||
" }\n"
|
||||
" return %4 : !torch.tuple<int, int>\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_dtype_fn.aten.atan2\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>) -> !torch.int {\n"
|
||||
" %int6 = torch.constant.int 6\n"
|
||||
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
|
||||
" %1:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
|
||||
" %2 = torch.prim.ListConstruct %0#0, %1#0 : (!torch.int, !torch.int) -> !torch.list<optional<int>>\n"
|
||||
" %3 = torch.prim.ListConstruct %0#1, %1#1 : (!torch.int, !torch.int) -> !torch.list<int>\n"
|
||||
" %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n"
|
||||
" %5 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_integer_dtype(%4) : (!torch.int) -> !torch.bool\n"
|
||||
" %6 = torch.prim.If %5 -> (!torch.int) {\n"
|
||||
" torch.prim.If.yield %int6 : !torch.int\n"
|
||||
" } else {\n"
|
||||
" torch.prim.If.yield %4 : !torch.int\n"
|
||||
" }\n"
|
||||
" return %6 : !torch.int\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_dtype_fn.aten.linear\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.optional<tuple<int, int>>) -> !torch.int {\n"
|
||||
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
|
||||
" return %0#1 : !torch.int\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_dtype_fn.aten.cat\"(%arg0: !torch.list<tuple<int, int>>, %arg1: !torch.int) -> !torch.int {\n"
|
||||
" %true = torch.constant.bool true\n"
|
||||
" %none = torch.constant.none\n"
|
||||
" %str = torch.constant.str \"AssertionError: \"\n"
|
||||
" %int0 = torch.constant.int 0\n"
|
||||
" %0 = torch.prim.ListConstruct : () -> !torch.list<optional<int>>\n"
|
||||
" %1 = torch.prim.ListConstruct : () -> !torch.list<int>\n"
|
||||
" %2 = torch.aten.len.t %arg0 : !torch.list<tuple<int, int>> -> !torch.int\n"
|
||||
" %3 = torch.aten.ne.int %2, %int0 : !torch.int, !torch.int -> !torch.bool\n"
|
||||
" torch.prim.If %3 -> () {\n"
|
||||
" torch.prim.If.yield\n"
|
||||
" } else {\n"
|
||||
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
|
||||
" torch.prim.If.yield\n"
|
||||
" }\n"
|
||||
" %4 = torch.aten.len.t %arg0 : !torch.list<tuple<int, int>> -> !torch.int\n"
|
||||
" torch.prim.Loop %4, %true, init() {\n"
|
||||
" ^bb0(%arg2: !torch.int):\n"
|
||||
" %6 = torch.aten.__getitem__.t %arg0, %arg2 : !torch.list<tuple<int, int>>, !torch.int -> !torch.tuple<int, int>\n"
|
||||
" %7:2 = torch.prim.TupleUnpack %6 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
|
||||
" %8 = torch.aten.append.t %0, %7#0 : !torch.list<optional<int>>, !torch.int -> !torch.list<optional<int>>\n"
|
||||
" %9 = torch.aten.append.t %1, %7#1 : !torch.list<int>, !torch.int -> !torch.list<int>\n"
|
||||
" torch.prim.Loop.condition %true, iter()\n"
|
||||
" } : (!torch.int, !torch.bool) -> ()\n"
|
||||
" %5 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%0, %1) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n"
|
||||
" return %5 : !torch.int\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_dtype_fn.aten._shape_as_tensor\"(%arg0: !torch.tuple<int, int>) -> !torch.int {\n"
|
||||
" %int4 = torch.constant.int 4\n"
|
||||
" return %int4 : !torch.int\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_dtype_fn.aten.ScalarImplicit\"(%arg0: !torch.tuple<int, int>) -> !torch.int {\n"
|
||||
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
|
||||
" return %0#1 : !torch.int\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_dtype_fn.prim.NumToTensor.Scalar\"(%arg0: !torch.union<float, int>) -> !torch.int {\n"
|
||||
" %0 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg0) : (!torch.union<float, int>) -> !torch.int\n"
|
||||
" return %0 : !torch.int\n"
|
||||
" }\n"
|
||||
"}\n"
|
||||
"";
|
||||
// clang-format on
|
||||
|
|
|
@ -520,21 +520,6 @@ static Type getPromotedResultScalarType(ArrayRef<Type> scalarTypes) {
|
|||
return *result;
|
||||
}
|
||||
|
||||
static SmallVector<std::optional<bool>>
|
||||
getRankIsNonZeroArray(ValueRange values) {
|
||||
SmallVector<std::optional<bool>> rankIsNonZero;
|
||||
for (Value v : values) {
|
||||
if (auto tensorType = v.getType().dyn_cast<BaseTensorType>()) {
|
||||
if (tensorType.hasSizes()) {
|
||||
rankIsNonZero.push_back(tensorType.getSizes().size() != 0);
|
||||
} else {
|
||||
rankIsNonZero.push_back(std::nullopt);
|
||||
}
|
||||
}
|
||||
}
|
||||
return rankIsNonZero;
|
||||
}
|
||||
|
||||
// Normally, tensor dimensions need to be known at compile time to do type
|
||||
// promotion. `skipRankCheck`, when equal to true, can be used to indicate
|
||||
// special cases that tensor operands are guaranteed to be not zero dimension
|
||||
|
@ -597,64 +582,6 @@ void TypeAnalysis::visitOperation(Operation *op,
|
|||
return;
|
||||
}
|
||||
|
||||
// Dtype is always float32, except for bfloat16, float64 and nullptr after
|
||||
// promotion and assuming possible-zero rank.
|
||||
if (isa<AtenAtan2Op>(op)) {
|
||||
ValueKnowledge knowledge =
|
||||
ValueKnowledge::getTensorPessimisticValueState(op->getContext());
|
||||
Type promotedDtype = getPromotedResultType(
|
||||
op->getContext(), {&operands[0]->getValue(), &operands[1]->getValue()},
|
||||
getRankIsNonZeroArray(op->getOperands()));
|
||||
if (promotedDtype) {
|
||||
knowledge.dtype = Float32Type::get(op->getContext());
|
||||
if (promotedDtype.isa<BFloat16Type, Float64Type>())
|
||||
knowledge.dtype = promotedDtype;
|
||||
}
|
||||
incorporateKnowledge(op->getResult(0), knowledge);
|
||||
return;
|
||||
}
|
||||
|
||||
if (auto linear = llvm::dyn_cast<AtenLinearOp>(op)) {
|
||||
visitAtenLinearOp(linear, operands);
|
||||
return;
|
||||
}
|
||||
|
||||
if (isa<AtenMaxPool2dWithIndicesOp>(op)) {
|
||||
auto self = operands[0]->getValue();
|
||||
auto result0Knowledge =
|
||||
ValueKnowledge::getTensorPessimisticValueState(op->getContext());
|
||||
result0Knowledge.dtype = self.dtype;
|
||||
auto result1Knowledge =
|
||||
ValueKnowledge::getTensorPessimisticValueState(op->getContext());
|
||||
result1Knowledge.dtype =
|
||||
IntegerType::get(op->getContext(), 64, IntegerType::Signed);
|
||||
incorporateKnowledge(op->getResult(0), result0Knowledge);
|
||||
incorporateKnowledge(op->getResult(1), result1Knowledge);
|
||||
return;
|
||||
}
|
||||
|
||||
if (auto tensor = dyn_cast<AtenTensorOp>(op)) {
|
||||
visitAtenTensorOp(tensor);
|
||||
return;
|
||||
}
|
||||
|
||||
if (auto cat = dyn_cast<AtenCatOp>(op)) {
|
||||
visitAtenCatLikeOp<AtenCatOp>(cat, operands);
|
||||
return;
|
||||
} else if (auto stack = dyn_cast<AtenStackOp>(op)) {
|
||||
visitAtenCatLikeOp<AtenStackOp>(stack, operands);
|
||||
return;
|
||||
}
|
||||
|
||||
if (auto shapeAsTensor = dyn_cast<Aten_ShapeAsTensorOp>(op)) {
|
||||
auto knowledge =
|
||||
ValueKnowledge::getTensorPessimisticValueState(op->getContext());
|
||||
knowledge.dtype =
|
||||
IntegerType::get(op->getContext(), 64, IntegerType::Signed);
|
||||
incorporateKnowledge(shapeAsTensor.getResult(), knowledge);
|
||||
return;
|
||||
}
|
||||
|
||||
if (auto embedding = dyn_cast<AtenEmbeddingOp>(op)) {
|
||||
auto knowledge =
|
||||
ValueKnowledge::getTensorPessimisticValueState(op->getContext());
|
||||
|
@ -683,11 +610,6 @@ void TypeAnalysis::visitOperation(Operation *op,
|
|||
return;
|
||||
}
|
||||
|
||||
if (auto numToTensorOp = dyn_cast<PrimNumToTensorScalarOp>(op)) {
|
||||
visitNumToTensorOp(numToTensorOp);
|
||||
return;
|
||||
}
|
||||
|
||||
if (isa<AtenAddIntOp, AtenSubIntOp, AtenMulIntOp, AtenAddOp>(op)) {
|
||||
visitBinaryScalarOp(op, operands);
|
||||
return;
|
||||
|
|
|
@ -1546,6 +1546,11 @@ def aten〇max_pool2d〡dtype(self_rank_dtype: Tuple[int, int], kernel_size: Lis
|
|||
self_rank, self_dtype = self_rank_dtype
|
||||
return self_dtype
|
||||
|
||||
@check_dtype_function(_check_tensors_with_the_same_dtype(tensor_shapes=[(2, 3, 5, 7)], kernel_size=[2, 2]))
|
||||
def aten〇max_pool2d_with_indices〡dtype(self_rank_dtype: Tuple[int, int], kernel_size: List[int], stride: List[int] = (), padding: List[int] = (0, 0), dilation: List[int] = (1, 1), ceil_mode: bool = False) -> Tuple[int, int]:
|
||||
self_rank, self_dtype = self_rank_dtype
|
||||
return self_dtype, torch.int64
|
||||
|
||||
@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1))
|
||||
def aten〇mish〡dtype(self_rank_dtype: Tuple[int, int]) -> int:
|
||||
self_rank, self_dtype = self_rank_dtype
|
||||
|
@ -2976,6 +2981,53 @@ def aten〇var_mean〡dtype(self_rank_dtype: Tuple[int, int], unbiased: bool = T
|
|||
return torch.float64, self_dtype
|
||||
return self_dtype, self_dtype
|
||||
|
||||
@check_dtype_function(_check_two_tensor_op())
|
||||
def aten〇atan2〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int]) -> int:
|
||||
self_rank, self_dtype = self_rank_dtype
|
||||
other_rank, other_dtype = other_rank_dtype
|
||||
ranks: List[Optional[int]] = [self_rank, other_rank]
|
||||
dtypes = [self_dtype, other_dtype]
|
||||
promoted_dtype = promote_dtypes(ranks, dtypes)
|
||||
if is_integer_dtype(promoted_dtype):
|
||||
return torch.float32
|
||||
return promoted_dtype
|
||||
|
||||
@check_dtype_function(_check_two_tensor_op())
|
||||
def aten〇linear〡dtype(input_rank_dtype: Tuple[int, int], weight_rank_dtype: Tuple[int, int], bias_rank_dtype: Optional[Tuple[int, int]] = None) -> int:
|
||||
input_rank, input_dtype = input_rank_dtype
|
||||
weight_rank, weight_dtype = weight_rank_dtype
|
||||
return input_dtype
|
||||
|
||||
@check_dtype_function(
|
||||
[Invocation([NonZeroDTensorWithDtype(torch.float32), NonZeroDTensorWithDtype(torch.int32)]),
|
||||
Invocation([NonZeroDTensorWithDtype(torch.float16), NonZeroDTensorWithDtype(torch.float64)]),
|
||||
Invocation([NonZeroDTensorWithDtype(torch.float32), NonZeroDTensorWithDtype(torch.int32),
|
||||
NonZeroDTensorWithDtype(torch.complex64)])])
|
||||
def aten〇cat〡dtype(tensors_rank_dtype: List[Tuple[int, int]], dim: int = 0) -> int:
|
||||
ranks: List[Optional[int]] = []
|
||||
dtypes: List[int] = []
|
||||
assert len(tensors_rank_dtype) != 0
|
||||
for tensor_rank_dtype in tensors_rank_dtype:
|
||||
tensor_rank, tensor_dtype = tensor_rank_dtype
|
||||
ranks.append(tensor_rank)
|
||||
dtypes.append(tensor_dtype)
|
||||
return promote_dtypes(ranks, dtypes)
|
||||
|
||||
@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1))
|
||||
def aten〇_shape_as_tensor〡dtype(self_rank_dtype: Tuple[int, int]) -> int:
|
||||
return torch.int64
|
||||
|
||||
# Does not work on meta backend
|
||||
#@check_dtype_function(_check_tensors_with_the_same_dtype(tensor_shapes=[()]))
|
||||
def aten〇ScalarImplicit〡dtype(a_rank_dtype: Tuple[int, int]) -> int:
|
||||
a_rank, a_dtype = a_rank_dtype
|
||||
return a_dtype
|
||||
|
||||
@check_dtype_function([Invocation(0), Invocation(0.0)])
|
||||
def prim〇NumToTensor〇Scalar〡dtype(a: Union[int, float]) -> int:
|
||||
return get_dtype_of_scalar(a)
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
# Main
|
||||
# ==============================================================================
|
||||
|
|
|
@ -3,73 +3,6 @@
|
|||
// This file is for tests for individual ops that require a new transfer
|
||||
// function (i.e. new code called from visitOperation).
|
||||
|
||||
// -----
|
||||
// CHECK-LABEL: func.func @torch.aten.linear(
|
||||
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,3],f32>,
|
||||
// CHECK-SAME: %[[ARG1:.*]]: !torch.vtensor<[5,3],f32>,
|
||||
// CHECK-SAME: %[[ARG2:.*]]: !torch.vtensor<[5],f32>) -> !torch.vtensor {
|
||||
// CHECK: %[[LINEAR:.*]] = torch.aten.linear %[[ARG0]], %[[ARG1]], %[[ARG2]] : !torch.vtensor<[?,3],f32>, !torch.vtensor<[5,3],f32>, !torch.vtensor<[5],f32> -> !torch.vtensor<*,f32>
|
||||
// CHECK: %[[RESULT:.*]] = torch.tensor_static_info_cast %[[LINEAR]] : !torch.vtensor<*,f32> to !torch.vtensor
|
||||
// CHECK: return %[[RESULT]] : !torch.vtensor
|
||||
func.func @torch.aten.linear(%arg0: !torch.vtensor<[?,3],f32>, %arg1: !torch.vtensor<[5,3],f32>, %arg2: !torch.vtensor<[5],f32>) -> !torch.vtensor {
|
||||
%1 = torch.aten.linear %arg0, %arg1, %arg2 : !torch.vtensor<[?,3],f32>, !torch.vtensor<[5,3],f32>, !torch.vtensor<[5],f32> -> !torch.vtensor
|
||||
return %1 : !torch.vtensor
|
||||
}
|
||||
|
||||
// -----
|
||||
// CHECK-LABEL: func.func @torch.aten.cat(
|
||||
// CHECK-SAME: %[[T1:.*]]: !torch.tensor<[?,1,4],f32>,
|
||||
// CHECK-SAME: %[[T2:.*]]: !torch.tensor<[2,3,4],f32>) -> !torch.tensor {
|
||||
// CHECK: %[[INT1:.*]] = torch.constant.int 1
|
||||
// CHECK: %[[TENSORS:.*]] = torch.prim.ListConstruct %[[T1]], %[[T2]] : (!torch.tensor<[?,1,4],f32>, !torch.tensor<[2,3,4],f32>) -> !torch.list<tensor>
|
||||
// CHECK: %[[RET:.*]] = torch.aten.cat %[[TENSORS]], %[[INT1]] : !torch.list<tensor>, !torch.int -> !torch.tensor<*,f32>
|
||||
// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[RET]] : !torch.tensor<*,f32> to !torch.tensor
|
||||
// CHECK: return %[[CAST]] : !torch.tensor
|
||||
func.func @torch.aten.cat(%t0: !torch.tensor<[?,1,4], f32>, %t1: !torch.tensor<[2,3,4], f32>) -> !torch.tensor {
|
||||
%int1 = torch.constant.int 1
|
||||
%tensorList = torch.prim.ListConstruct %t0, %t1: (!torch.tensor<[?,1,4], f32>, !torch.tensor<[2,3,4], f32>) -> !torch.list<tensor>
|
||||
%ret = torch.aten.cat %tensorList, %int1 : !torch.list<tensor>, !torch.int -> !torch.tensor
|
||||
return %ret : !torch.tensor
|
||||
}
|
||||
|
||||
// -----
|
||||
// CHECK-LABEL: func.func @torch.aten.cat$promote_type(
|
||||
// CHECK-SAME: %[[T1:.*]]: !torch.tensor<[2,1,4],i1>,
|
||||
// CHECK-SAME: %[[T2:.*]]: !torch.tensor<[2,3,4],si64>) -> !torch.tensor {
|
||||
// CHECK: %[[INT1:.*]] = torch.constant.int 1
|
||||
// CHECK: %[[TENSORS:.*]] = torch.prim.ListConstruct %[[T1]], %[[T2]] : (!torch.tensor<[2,1,4],i1>, !torch.tensor<[2,3,4],si64>) -> !torch.list<tensor>
|
||||
// CHECK: %[[RET:.*]] = torch.aten.cat %[[TENSORS]], %[[INT1]] : !torch.list<tensor>, !torch.int -> !torch.tensor<*,si64>
|
||||
// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[RET]] : !torch.tensor<*,si64> to !torch.tensor
|
||||
// CHECK: return %[[CAST]] : !torch.tensor
|
||||
func.func @torch.aten.cat$promote_type(%t0: !torch.tensor<[2,1,4], i1>, %t1: !torch.tensor<[2,3,4], si64>) -> !torch.tensor {
|
||||
%int1 = torch.constant.int 1
|
||||
%tensorList = torch.prim.ListConstruct %t0, %t1: (!torch.tensor<[2,1,4], i1>, !torch.tensor<[2,3,4], si64>) -> !torch.list<tensor>
|
||||
%ret = torch.aten.cat %tensorList, %int1 : !torch.list<tensor>, !torch.int -> !torch.tensor
|
||||
return %ret : !torch.tensor
|
||||
}
|
||||
|
||||
// -----
|
||||
// CHECK-LABEL: func.func @torch.aten._shape_as_tensor(
|
||||
// CHECK-SAME: %[[INPUT:.*]]: !torch.tensor<[?,1,4],f32>) -> !torch.tensor {
|
||||
// CHECK: %[[RET:.*]] = torch.aten._shape_as_tensor %[[INPUT]] : !torch.tensor<[?,1,4],f32> -> !torch.tensor<*,si64>
|
||||
// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[RET]] : !torch.tensor<*,si64> to !torch.tensor
|
||||
// CHECK: return %[[CAST]] : !torch.tensor
|
||||
func.func @torch.aten._shape_as_tensor(%input: !torch.tensor<[?,1,4], f32>) -> !torch.tensor {
|
||||
%ret= torch.aten._shape_as_tensor %input : !torch.tensor<[?,1,4], f32> -> !torch.tensor
|
||||
return %ret : !torch.tensor
|
||||
}
|
||||
|
||||
// -----
|
||||
// CHECK-LABEL: func.func @torch.aten._shape_as_tensor$unknown_input_shape(
|
||||
// CHECK-SAME: %[[INPUT:.*]]: !torch.tensor) -> !torch.tensor {
|
||||
// CHECK: %[[RET:.*]] = torch.aten._shape_as_tensor %[[INPUT]] : !torch.tensor -> !torch.tensor<*,si64>
|
||||
// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[RET]] : !torch.tensor<*,si64> to !torch.tensor
|
||||
// CHECK: return %[[CAST]] : !torch.tensor
|
||||
func.func @torch.aten._shape_as_tensor$unknown_input_shape(%input: !torch.tensor) -> !torch.tensor {
|
||||
%ret= torch.aten._shape_as_tensor %input : !torch.tensor -> !torch.tensor
|
||||
return %ret : !torch.tensor
|
||||
}
|
||||
|
||||
// -----
|
||||
// CHECK-LABEL: func.func @torch.aten.embedding(
|
||||
// CHECK-SAME: %[[INPUT:.*]]: !torch.tensor<[104,512],f32>,
|
||||
|
@ -114,47 +47,3 @@ func.func @torch.aten.softmax.int$specified_dtype(%t: !torch.tensor<[2,3],f32>,
|
|||
return %ret : !torch.tensor
|
||||
}
|
||||
|
||||
// -----
|
||||
// CHECK-LABEL: func.func @torch.prim.NumToTensor.Scalar(
|
||||
// CHECK-SAME: %[[SELF:.*]]: !torch.int) -> !torch.tensor {
|
||||
// CHECK: %[[NTT:.*]] = torch.prim.NumToTensor.Scalar %[[SELF]] : !torch.int -> !torch.tensor<*,si64>
|
||||
// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[NTT]] : !torch.tensor<*,si64> to !torch.tensor
|
||||
// CHECK: return %[[CAST]] : !torch.tensor
|
||||
func.func @torch.prim.NumToTensor.Scalar(%arg0: !torch.int) -> !torch.tensor {
|
||||
%0 = torch.prim.NumToTensor.Scalar %arg0: !torch.int -> !torch.tensor
|
||||
return %0: !torch.tensor
|
||||
}
|
||||
|
||||
// -----
|
||||
// CHECK-LABEL: func.func @torch.aten.tensor(
|
||||
// CHECK-SAME: %[[DATA:.*]]: !torch.list<list<float>>) -> !torch.tensor {
|
||||
// CHECK: %[[NONE:.*]] = torch.constant.none
|
||||
// CHECK: %[[FALSE:.*]] = torch.constant.bool false
|
||||
// CHECK: %[[RET:.*]] = torch.aten.tensor %[[DATA]], %[[NONE]], %[[NONE]], %[[FALSE]]
|
||||
// CHECK-SAME: : !torch.list<list<float>>, !torch.none, !torch.none, !torch.bool
|
||||
// CHECK-SAME: -> !torch.tensor<*,f32>
|
||||
// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[RET]] : !torch.tensor<*,f32> to !torch.tensor
|
||||
// CHECK: return %[[CAST]] : !torch.tensor
|
||||
func.func @torch.aten.tensor(%t: !torch.list<list<float>>) -> !torch.tensor {
|
||||
%none = torch.constant.none
|
||||
%false = torch.constant.bool false
|
||||
%ret = torch.aten.tensor %t, %none, %none, %false : !torch.list<list<float>>, !torch.none, !torch.none, !torch.bool -> !torch.tensor
|
||||
return %ret : !torch.tensor
|
||||
}
|
||||
|
||||
// -----
|
||||
// CHECK-LABEL: func.func @torch.aten.tensor$specified_dtype(
|
||||
// CHECK-SAME: %[[DATA:.*]]: !torch.list<list<float>>) -> !torch.tensor {
|
||||
// CHECK: %[[NONE:.*]] = torch.constant.none
|
||||
// CHECK: %[[INT4:.*]] = torch.constant.int 4
|
||||
// CHECK: %[[FALSE:.*]] = torch.constant.bool false
|
||||
// CHECK: %[[RET:.*]] = torch.aten.tensor %[[DATA]], %[[INT4]], %[[NONE]], %[[FALSE]] : !torch.list<list<float>>, !torch.int, !torch.none, !torch.bool -> !torch.tensor<*,si64>
|
||||
// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[RET]] : !torch.tensor<*,si64> to !torch.tensor
|
||||
// CHECK: return %[[CAST]] : !torch.tensor
|
||||
func.func @torch.aten.tensor$specified_dtype(%t: !torch.list<list<float>>) -> !torch.tensor {
|
||||
%none = torch.constant.none
|
||||
%int4 = torch.constant.int 4
|
||||
%false = torch.constant.bool false
|
||||
%ret = torch.aten.tensor %t, %int4, %none, %false : !torch.list<list<float>>, !torch.int, !torch.none, !torch.bool -> !torch.tensor
|
||||
return %ret : !torch.tensor
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue