From 64afc08dab903d4f531dd51056e94fdcb39cad95 Mon Sep 17 00:00:00 2001
From: Yuanqiang Liu
Date: Fri, 23 Jun 2023 16:11:33 +0800
Subject: [PATCH] [Torch Dialect] add missing one_hot dtype function (#2143)

* [Torch Dialect] add missing one_hot dtype function

* update

* update

* update
---
 .../Torch/Transforms/AbstractInterpLibrary.cpp    | 14 ++++++++++++++
 .../Torch/Transforms/DecomposeComplexOps.cpp      |  7 +------
 .../jit_ir/build_tools/abstract_interp_lib_gen.py |  7 +++++++
 3 files changed, 22 insertions(+), 6 deletions(-)

diff --git a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
index dbc2bc617..bea8f969e 100644
--- a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
+++ b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
@@ -9548,6 +9548,20 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    %7 = torch.prim.TupleConstruct %0#1, %0#1, %6 : !torch.int, !torch.int, !torch.int -> !torch.tuple<int, int, int>\n"
 "    return %7 : !torch.tuple<int, int, int>\n"
 "  }\n"
+"  func.func @\"__torch_mlir_dtype_fn.aten.one_hot\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.int) -> !torch.int {\n"
+"    %none = torch.constant.none\n"
+"    %str = torch.constant.str \"AssertionError: \"\n"
+"    %int4 = torch.constant.int 4\n"
+"    %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
+"    %1 = torch.aten.eq.int %0#1, %int4 : !torch.int, !torch.int -> !torch.bool\n"
+"    torch.prim.If %1 -> () {\n"
+"      torch.prim.If.yield\n"
+"    } else {\n"
+"      torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
+"      torch.prim.If.yield\n"
+"    }\n"
+"    return %int4 : !torch.int\n"
+"  }\n"
 "  func.func @\"__torch_mlir_dtype_fn.aten.native_batch_norm\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.optional<tuple<int, int>>, %arg2: !torch.optional<tuple<int, int>>, %arg3: !torch.optional<tuple<int, int>>, %arg4: !torch.optional<tuple<int, int>>, %arg5: !torch.bool, %arg6: !torch.float, %arg7: !torch.float) -> !torch.tuple<int, int, int> {\n"
 "    %int6 = torch.constant.int 6\n"
 "    %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
index 77ab5489d..150a1f976 100644
--- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
+++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
@@ -4313,7 +4313,6 @@ class DecomposeAtenOneHotOp : public OpRewritePattern<AtenOneHotOp> {
       return rewriter.notifyMatchFailure(
           op, "unimplemented: num_classes must be constant");
     Value none = rewriter.create<ConstantNoneOp>(loc);
-    Value falseValue = rewriter.create<ConstantBoolOp>(loc, false);
 
     // arange tensor
     auto si64Type = IntegerType::get(context, 64, IntegerType::Signed);
@@ -4341,11 +4340,7 @@
         loc, eqType, unsqueezeTensor, arangeTensor);
 
     // convert to si64
-    Value si64TypeValue =
-        Torch::getDtypeIntValueForType(rewriter, loc, si64Type);
-    Value result = rewriter.create<AtenToDtypeOp>(
-        loc, op.getType(), eqTensor, si64TypeValue, /*non_blocking=*/falseValue,
-        /*copy=*/falseValue, /*memory_format=*/none);
+    Value result = convertTensorToDtype(rewriter, loc, eqTensor, si64Type);
     rewriter.replaceOp(op, result);
     return success();
   }
diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/abstract_interp_lib_gen.py b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/abstract_interp_lib_gen.py
index 96f4a5ab3..216000970 100644
--- a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/abstract_interp_lib_gen.py
+++ b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/abstract_interp_lib_gen.py
@@ -2699,6 +2699,13 @@ def aten〇native_layer_norm〡dtype(input_rank_dtype: Tuple[int, int], normaliz
         result_dtype = torch.float64
     return input_dtype, input_dtype, result_dtype
 
+# note: one_hot doesn't support "meta" device, use "cpu" instead.
+@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, num_classes=2, tensor_device="cpu", error_types={torch.complex128, torch.complex64, torch.float64, torch.float32, torch.float16, torch.bfloat16, torch.int32, torch.int16, torch.int8, torch.uint8, torch.bool}))
+def aten〇one_hot〡dtype(self_rank_dtype: Tuple[int, int], num_classes: int = -1) -> int:
+    self_rank, self_dtype = self_rank_dtype
+    assert self_dtype == torch.int64
+    return torch.int64
+
 @check_dtype_function(
     [Invocation(TensorOfShape(3, 3, dtype=torch.float32), TensorOfShape(3, dtype=torch.float32), TensorOfShape(3, dtype=torch.float32), TensorOfShape(3, dtype=torch.float32),
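
Reviewer note (not part of the patch): a minimal eager-PyTorch sketch of the contract this change encodes. torch.nn.functional.one_hot accepts only int64 ("Long") input and always returns int64 output, which is why the new dtype function asserts the input dtype is int64 and returns int64 (dtype code 4 in the generated MLIR). The same snippet checks the decomposition used by DecomposeAtenOneHotOp: compare against arange, then cast the boolean result to int64.

import torch

# one_hot requires an int64 ("Long") index tensor.
labels = torch.tensor([0, 2, 1], dtype=torch.int64)
num_classes = 3

# The decomposition from DecomposeAtenOneHotOp, expressed in eager PyTorch:
# unsqueeze the labels, compare against arange(num_classes), cast bool -> int64.
arange = torch.arange(num_classes, dtype=torch.int64)
decomposed = labels.unsqueeze(-1).eq(arange).to(torch.int64)

# Matches the reference op in both values and dtype.
reference = torch.nn.functional.one_hot(labels, num_classes)
assert torch.equal(decomposed, reference)
assert reference.dtype == torch.int64  # the dtype the new dtype function returns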