[TOSA] Add torch.prim.NumToTensor.Scalar float support (#1802)

pull/2047/head snapshot-20230419.813
Chi_Liu 2023-04-18 13:36:57 -07:00 committed by GitHub
parent 4d98f76d4f
commit 8d25dd454f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 20 additions and 6 deletions

View File

@ -700,6 +700,12 @@ TOSA_PASS_SET = {
"GatherStaticModule_basic",
"IndexTensorStaticModule_basic",
"IndexTensorMultiIndexStaticModule_basic",
"ElementwiseWhereScalarModule_basic",
"FullLikeModuleFloat3DStatic_basic",
"FullModuleDefaultDtype_basic",
"FullModuleFloat3D_basic",
"MaskedFillScalarDefaultModule_basic",
"NumToTensorFloatModule_basic",
"LiftFreshCopyModule_basic",
"ReduceSumDimIntListKeepDimNegativeDimStaticModule_basic",
"ReduceSumDimIntListFloatModule_basic",

View File

@ -3718,13 +3718,21 @@ LogicalResult ConvertAtenOp<PrimNumToTensorScalarOp>::matchAndRewrite(
// Only supports integer operand type, because for the floating point operand
// type result tensor has to be of type `f64` which is not supported in the
// tosa.
int64_t initValue;
if (!matchPattern(op.getA(), m_TorchConstantInt(&initValue)))
return rewriter.notifyMatchFailure(
op, "unimplemented: input should be a torch constant int");
double doubleValue;
auto isDouble = matchPattern(op.getA(), m_TorchConstantFloat(&doubleValue));
int64_t intValue;
auto isInt = matchPattern(op.getA(), m_TorchConstantInt(&intValue));
if (!isDouble && !isInt)
return rewriter.notifyMatchFailure(op,
"Unable to extract the scalar constant");
auto outElemTy = resultType.getElementType();
if (outElemTy.isa<mlir::IntegerType>()) {
rewriter.replaceOpWithNewOp<tosa::ConstOp>(op, resultType, DenseElementsAttr::get(resultType, {intValue}));
} else if (outElemTy.isF64()) {
rewriter.replaceOpWithNewOp<tosa::ConstOp>(op, resultType, DenseElementsAttr::get(resultType, {doubleValue}));
}
DenseElementsAttr constAttr = DenseElementsAttr::get(resultType, {initValue});
rewriter.replaceOpWithNewOp<tosa::ConstOp>(op, resultType, constAttr);
return success();
}