mirror of https://github.com/llvm/torch-mlir
parent
4d98f76d4f
commit
8d25dd454f
|
@ -700,6 +700,12 @@ TOSA_PASS_SET = {
|
|||
"GatherStaticModule_basic",
|
||||
"IndexTensorStaticModule_basic",
|
||||
"IndexTensorMultiIndexStaticModule_basic",
|
||||
"ElementwiseWhereScalarModule_basic",
|
||||
"FullLikeModuleFloat3DStatic_basic",
|
||||
"FullModuleDefaultDtype_basic",
|
||||
"FullModuleFloat3D_basic",
|
||||
"MaskedFillScalarDefaultModule_basic",
|
||||
"NumToTensorFloatModule_basic",
|
||||
"LiftFreshCopyModule_basic",
|
||||
"ReduceSumDimIntListKeepDimNegativeDimStaticModule_basic",
|
||||
"ReduceSumDimIntListFloatModule_basic",
|
||||
|
|
|
@ -3718,13 +3718,21 @@ LogicalResult ConvertAtenOp<PrimNumToTensorScalarOp>::matchAndRewrite(
|
|||
// Only supports integer operand type, because for the floating point operand
|
||||
// type result tensor has to be of type `f64` which is not supported in the
|
||||
// tosa.
|
||||
int64_t initValue;
|
||||
if (!matchPattern(op.getA(), m_TorchConstantInt(&initValue)))
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "unimplemented: input should be a torch constant int");
|
||||
double doubleValue;
|
||||
auto isDouble = matchPattern(op.getA(), m_TorchConstantFloat(&doubleValue));
|
||||
int64_t intValue;
|
||||
auto isInt = matchPattern(op.getA(), m_TorchConstantInt(&intValue));
|
||||
if (!isDouble && !isInt)
|
||||
return rewriter.notifyMatchFailure(op,
|
||||
"Unable to extract the scalar constant");
|
||||
|
||||
auto outElemTy = resultType.getElementType();
|
||||
if (outElemTy.isa<mlir::IntegerType>()) {
|
||||
rewriter.replaceOpWithNewOp<tosa::ConstOp>(op, resultType, DenseElementsAttr::get(resultType, {intValue}));
|
||||
} else if (outElemTy.isF64()) {
|
||||
rewriter.replaceOpWithNewOp<tosa::ConstOp>(op, resultType, DenseElementsAttr::get(resultType, {doubleValue}));
|
||||
}
|
||||
|
||||
DenseElementsAttr constAttr = DenseElementsAttr::get(resultType, {initValue});
|
||||
rewriter.replaceOpWithNewOp<tosa::ConstOp>(op, resultType, constAttr);
|
||||
return success();
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue