[onnx] Fix onnx.cast cases between int32 and int64 (#2982)

2 modifications:
1. torch.int64 is enum 4 in TORCH_DTYPE_TO_INT
2. add int32 support
pull/3032/head
Xinan Jiang(姜曦楠) 2024-03-16 01:14:09 +08:00 committed by GitHub
parent f34c187ac4
commit d8a52e82c2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 3 additions and 4 deletions

View File

@ -43,8 +43,10 @@ static int64_t onnxDtypeIntToTorchDtypeInt(int64_t dtypeIntOnnx) {
switch (dtypeIntOnnx) { switch (dtypeIntOnnx) {
case 1: case 1:
return 6; // float return 6; // float
case 6:
return 3; // int32
case 7: case 7:
return 5; // int64 return 4; // int64
case 9: case 9:
return 11; // bool return 11; // bool
case 10: case 10:

View File

@ -1883,10 +1883,7 @@ ONNX_XFAIL_SET = {
"BucketizeTensorOutInt32RightModule_basic", "BucketizeTensorOutInt32RightModule_basic",
"ElementwiseToDtypeI64ToI8Module_basic", "ElementwiseToDtypeI64ToI8Module_basic",
"ElementwiseToDtypeI64ToUI8Module_basic", "ElementwiseToDtypeI64ToUI8Module_basic",
"HBC_basic",
"QuantizedMLP_basic", "QuantizedMLP_basic",
"TypeConversionI1ToI32Module_basic",
"TypeConversionI64ToI32Module_basic",
# Failure - onnx_lowering: onnx.Clip # Failure - onnx_lowering: onnx.Clip
"NormalizeModule_basic", "NormalizeModule_basic",