mirror of https://github.com/llvm/torch-mlir
[Torch Dialect] add missing one_hot dtype function (#2143)
* [Torch Dialect] add missing one_hot dtype function * update * update * updatepull/2260/head snapshot-20230623.878
parent
39201a4be5
commit
64afc08dab
|
@ -9548,6 +9548,20 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
|
|||
" %7 = torch.prim.TupleConstruct %0#1, %0#1, %6 : !torch.int, !torch.int, !torch.int -> !torch.tuple<int, int, int>\n"
|
||||
" return %7 : !torch.tuple<int, int, int>\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_dtype_fn.aten.one_hot\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.int) -> !torch.int {\n"
|
||||
" %none = torch.constant.none\n"
|
||||
" %str = torch.constant.str \"AssertionError: \"\n"
|
||||
" %int4 = torch.constant.int 4\n"
|
||||
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
|
||||
" %1 = torch.aten.eq.int %0#1, %int4 : !torch.int, !torch.int -> !torch.bool\n"
|
||||
" torch.prim.If %1 -> () {\n"
|
||||
" torch.prim.If.yield\n"
|
||||
" } else {\n"
|
||||
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
|
||||
" torch.prim.If.yield\n"
|
||||
" }\n"
|
||||
" return %int4 : !torch.int\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_dtype_fn.aten.native_batch_norm\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.optional<tuple<int, int>>, %arg2: !torch.optional<tuple<int, int>>, %arg3: !torch.optional<tuple<int, int>>, %arg4: !torch.optional<tuple<int, int>>, %arg5: !torch.bool, %arg6: !torch.float, %arg7: !torch.float) -> !torch.tuple<int, int, int> {\n"
|
||||
" %int6 = torch.constant.int 6\n"
|
||||
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
|
||||
|
|
|
@ -4313,7 +4313,6 @@ class DecomposeAtenOneHotOp : public OpRewritePattern<AtenOneHotOp> {
|
|||
return rewriter.notifyMatchFailure(
|
||||
op, "unimplemented: num_classes must be constant");
|
||||
Value none = rewriter.create<ConstantNoneOp>(loc);
|
||||
Value falseValue = rewriter.create<ConstantBoolOp>(loc, false);
|
||||
|
||||
// arange tensor
|
||||
auto si64Type = IntegerType::get(context, 64, IntegerType::Signed);
|
||||
|
@ -4341,11 +4340,7 @@ class DecomposeAtenOneHotOp : public OpRewritePattern<AtenOneHotOp> {
|
|||
loc, eqType, unsqueezeTensor, arangeTensor);
|
||||
|
||||
// convert to si64
|
||||
Value si64TypeValue =
|
||||
Torch::getDtypeIntValueForType(rewriter, loc, si64Type);
|
||||
Value result = rewriter.create<AtenToDtypeOp>(
|
||||
loc, op.getType(), eqTensor, si64TypeValue, /*non_blocking=*/falseValue,
|
||||
/*copy=*/falseValue, /*memory_format=*/none);
|
||||
Value result = convertTensorToDtype(rewriter, loc, eqTensor, si64Type);
|
||||
rewriter.replaceOp(op, result);
|
||||
return success();
|
||||
}
|
||||
|
|
|
@ -2699,6 +2699,13 @@ def aten〇native_layer_norm〡dtype(input_rank_dtype: Tuple[int, int], normaliz
|
|||
result_dtype = torch.float64
|
||||
return input_dtype, input_dtype, result_dtype
|
||||
|
||||
# note: one_hot doesn't support "meta" device, use "cpu" instead.
|
||||
@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, num_classes=2, tensor_device="cpu", error_types={torch.complex128, torch.complex64, torch.float64, torch.float32, torch.float16, torch.bfloat16, torch.int32, torch.int16, torch.int8, torch.uint8, torch.bool}))
|
||||
def aten〇one_hot〡dtype(self_rank_dtype: Tuple[int, int], num_classes: int = -1) -> int:
|
||||
self_rank, self_dtype = self_rank_dtype
|
||||
assert self_dtype == torch.int64
|
||||
return torch.int64
|
||||
|
||||
@check_dtype_function(
|
||||
[Invocation(TensorOfShape(3, 3, dtype=torch.float32), TensorOfShape(3, dtype=torch.float32),
|
||||
TensorOfShape(3, dtype=torch.float32), TensorOfShape(3, dtype=torch.float32),
|
||||
|
|
Loading…
Reference in New Issue