From 8d25dd454f8d66a8e60ec5c541e586cb43cfda9a Mon Sep 17 00:00:00 2001 From: Chi_Liu Date: Tue, 18 Apr 2023 13:36:57 -0700 Subject: [PATCH] [TOSA] Add torch.prim.NumToTensor.Scalar float support (#1802) --- e2e_testing/xfail_sets.py | 6 ++++++ lib/Conversion/TorchToTosa/TorchToTosa.cpp | 20 ++++++++++++++------ 2 files changed, 20 insertions(+), 6 deletions(-) diff --git a/e2e_testing/xfail_sets.py b/e2e_testing/xfail_sets.py index 87e79a5b5..5192422de 100644 --- a/e2e_testing/xfail_sets.py +++ b/e2e_testing/xfail_sets.py @@ -700,6 +700,12 @@ TOSA_PASS_SET = { "GatherStaticModule_basic", "IndexTensorStaticModule_basic", "IndexTensorMultiIndexStaticModule_basic", + "ElementwiseWhereScalarModule_basic", + "FullLikeModuleFloat3DStatic_basic", + "FullModuleDefaultDtype_basic", + "FullModuleFloat3D_basic", + "MaskedFillScalarDefaultModule_basic", + "NumToTensorFloatModule_basic", "LiftFreshCopyModule_basic", "ReduceSumDimIntListKeepDimNegativeDimStaticModule_basic", "ReduceSumDimIntListFloatModule_basic", diff --git a/lib/Conversion/TorchToTosa/TorchToTosa.cpp b/lib/Conversion/TorchToTosa/TorchToTosa.cpp index e4cdbd004..58bd2f8f3 100644 --- a/lib/Conversion/TorchToTosa/TorchToTosa.cpp +++ b/lib/Conversion/TorchToTosa/TorchToTosa.cpp @@ -3718,13 +3718,21 @@ LogicalResult ConvertAtenOp::matchAndRewrite( // Only supports integer operand type, because for the floating point operand // type result tensor has to be of type `f64` which is not supported in the // tosa. - int64_t initValue; - if (!matchPattern(op.getA(), m_TorchConstantInt(&initValue))) - return rewriter.notifyMatchFailure( - op, "unimplemented: input should be a torch constant int"); + double doubleValue; + auto isDouble = matchPattern(op.getA(), m_TorchConstantFloat(&doubleValue)); + int64_t intValue; + auto isInt = matchPattern(op.getA(), m_TorchConstantInt(&intValue)); + if (!isDouble && !isInt) + return rewriter.notifyMatchFailure(op, + "Unable to extract the scalar constant"); + + auto outElemTy = resultType.getElementType(); + if (outElemTy.isa()) { + rewriter.replaceOpWithNewOp(op, resultType, DenseElementsAttr::get(resultType, {intValue})); + } else if (outElemTy.isF64()) { + rewriter.replaceOpWithNewOp(op, resultType, DenseElementsAttr::get(resultType, {doubleValue})); + } - DenseElementsAttr constAttr = DenseElementsAttr::get(resultType, {initValue}); - rewriter.replaceOpWithNewOp(op, resultType, constAttr); return success(); }