[Torch Dialect] add missing one_hot dtype function (#2143)

* [Torch Dialect] add missing one_hot dtype function

* update

* update

* update
pull/2260/head snapshot-20230623.878
Yuanqiang Liu 2023-06-23 16:11:33 +08:00 committed by GitHub
parent 39201a4be5
commit 64afc08dab
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 22 additions and 6 deletions

View File

@ -9548,6 +9548,20 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %7 = torch.prim.TupleConstruct %0#1, %0#1, %6 : !torch.int, !torch.int, !torch.int -> !torch.tuple<int, int, int>\n"
" return %7 : !torch.tuple<int, int, int>\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.one_hot\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.int) -> !torch.int {\n"
" %none = torch.constant.none\n"
" %str = torch.constant.str \"AssertionError: \"\n"
" %int4 = torch.constant.int 4\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" %1 = torch.aten.eq.int %0#1, %int4 : !torch.int, !torch.int -> !torch.bool\n"
" torch.prim.If %1 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" return %int4 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.native_batch_norm\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.optional<tuple<int, int>>, %arg2: !torch.optional<tuple<int, int>>, %arg3: !torch.optional<tuple<int, int>>, %arg4: !torch.optional<tuple<int, int>>, %arg5: !torch.bool, %arg6: !torch.float, %arg7: !torch.float) -> !torch.tuple<int, int, int> {\n"
" %int6 = torch.constant.int 6\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"

View File

@ -4313,7 +4313,6 @@ class DecomposeAtenOneHotOp : public OpRewritePattern<AtenOneHotOp> {
return rewriter.notifyMatchFailure(
op, "unimplemented: num_classes must be constant");
Value none = rewriter.create<ConstantNoneOp>(loc);
Value falseValue = rewriter.create<ConstantBoolOp>(loc, false);
// arange tensor
auto si64Type = IntegerType::get(context, 64, IntegerType::Signed);
@ -4341,11 +4340,7 @@ class DecomposeAtenOneHotOp : public OpRewritePattern<AtenOneHotOp> {
loc, eqType, unsqueezeTensor, arangeTensor);
// convert to si64
Value si64TypeValue =
Torch::getDtypeIntValueForType(rewriter, loc, si64Type);
Value result = rewriter.create<AtenToDtypeOp>(
loc, op.getType(), eqTensor, si64TypeValue, /*non_blocking=*/falseValue,
/*copy=*/falseValue, /*memory_format=*/none);
Value result = convertTensorToDtype(rewriter, loc, eqTensor, si64Type);
rewriter.replaceOp(op, result);
return success();
}

View File

@ -2699,6 +2699,13 @@ def atennative_layer_norm〡dtype(input_rank_dtype: Tuple[int, int], normaliz
result_dtype = torch.float64
return input_dtype, input_dtype, result_dtype
# note: one_hot doesn't support "meta" device, use "cpu" instead.
@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, num_classes=2, tensor_device="cpu", error_types={torch.complex128, torch.complex64, torch.float64, torch.float32, torch.float16, torch.bfloat16, torch.int32, torch.int16, torch.int8, torch.uint8, torch.bool}))
def atenone_hot〡dtype(self_rank_dtype: Tuple[int, int], num_classes: int = -1) -> int:
self_rank, self_dtype = self_rank_dtype
assert self_dtype == torch.int64
return torch.int64
@check_dtype_function(
[Invocation(TensorOfShape(3, 3, dtype=torch.float32), TensorOfShape(3, dtype=torch.float32),
TensorOfShape(3, dtype=torch.float32), TensorOfShape(3, dtype=torch.float32),