TorchToTosa: Legalization for torch.aten.sqrt (#2234)

pull/2308/head snapshot-20230714.899
Tiago Trevisan Jost 2023-07-14 08:23:10 +02:00 committed by GitHub
parent 7f6b72aec8
commit 48383554da
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 32 additions and 0 deletions

View File

@ -1108,6 +1108,8 @@ TOSA_PASS_SET = {
"TensorsConcatStaticModule_basic", "TensorsConcatStaticModule_basic",
"TensorsConcatNegativeDimStaticModule_basic", "TensorsConcatNegativeDimStaticModule_basic",
"AtenComplex64Module_basic", "AtenComplex64Module_basic",
"ElementwiseSqrtIntModule_basic",
"ElementwiseSqrtModule_basic",
"SplitTensorGetItem_Module_basic", "SplitTensorGetItem_Module_basic",
"SplitTensorListUnpackModule_basic", "SplitTensorListUnpackModule_basic",
"ChunkListUnpack_Module_basic", "ChunkListUnpack_Module_basic",

View File

@ -4563,6 +4563,35 @@ LogicalResult ConvertAtenOp<AtenCatOp>::matchAndRewrite(
return success(); return success();
} }
template <>
LogicalResult ConvertAtenOp<AtenSqrtOp>::matchAndRewrite(
AtenSqrtOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
// Converts AtenSqrtOp into pow(x, 0.5)
auto self = adaptor.getSelf();
auto selfTy = self.getType().dyn_cast<TensorType>();
if (!selfTy)
return rewriter.notifyMatchFailure(op,
"Only Tensor types supported in TOSA");
auto resultType = typeConverter->convertType(op.getType())
.template cast<RankedTensorType>();
auto elementType = resultType.getElementType();
if (isa<mlir::IntegerType>(selfTy.getElementType())) {
self = rewriter.createOrFold<tosa::CastOp>(
op->getLoc(), RankedTensorType::get(resultType.getShape(), elementType),
self);
}
auto oneHalf =
tosa::getConstTensor<float>(rewriter, op, 0.5, {}, elementType).value();
rewriter.replaceOpWithNewOp<tosa::PowOp>(op, resultType, self, oneHalf);
return success();
}
} // namespace } // namespace
// ----------------------------------------------------------------------------- // -----------------------------------------------------------------------------
@ -4792,6 +4821,7 @@ public:
INSERT_ATENOP_PATTERN(AtenConstantPadNdOp); INSERT_ATENOP_PATTERN(AtenConstantPadNdOp);
INSERT_ATENOP_PATTERN(AtenRemainderScalarOp); INSERT_ATENOP_PATTERN(AtenRemainderScalarOp);
INSERT_ATENOP_PATTERN(AtenCatOp); INSERT_ATENOP_PATTERN(AtenCatOp);
INSERT_ATENOP_PATTERN(AtenSqrtOp);
#undef INSERT_ATENOP_PATTERN #undef INSERT_ATENOP_PATTERN
#define INSERT_CLONE_ATENOP_PATTERN(AtenOp) \ #define INSERT_CLONE_ATENOP_PATTERN(AtenOp) \