mirror of https://github.com/llvm/torch-mlir
parent
7f6b72aec8
commit
48383554da
|
@ -1108,6 +1108,8 @@ TOSA_PASS_SET = {
|
|||
"TensorsConcatStaticModule_basic",
|
||||
"TensorsConcatNegativeDimStaticModule_basic",
|
||||
"AtenComplex64Module_basic",
|
||||
"ElementwiseSqrtIntModule_basic",
|
||||
"ElementwiseSqrtModule_basic",
|
||||
"SplitTensorGetItem_Module_basic",
|
||||
"SplitTensorListUnpackModule_basic",
|
||||
"ChunkListUnpack_Module_basic",
|
||||
|
|
|
@ -4563,6 +4563,35 @@ LogicalResult ConvertAtenOp<AtenCatOp>::matchAndRewrite(
|
|||
return success();
|
||||
}
|
||||
|
||||
template <>
|
||||
LogicalResult ConvertAtenOp<AtenSqrtOp>::matchAndRewrite(
|
||||
AtenSqrtOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
|
||||
// Converts AtenSqrtOp into pow(x, 0.5)
|
||||
auto self = adaptor.getSelf();
|
||||
auto selfTy = self.getType().dyn_cast<TensorType>();
|
||||
if (!selfTy)
|
||||
return rewriter.notifyMatchFailure(op,
|
||||
"Only Tensor types supported in TOSA");
|
||||
|
||||
auto resultType = typeConverter->convertType(op.getType())
|
||||
.template cast<RankedTensorType>();
|
||||
auto elementType = resultType.getElementType();
|
||||
|
||||
if (isa<mlir::IntegerType>(selfTy.getElementType())) {
|
||||
self = rewriter.createOrFold<tosa::CastOp>(
|
||||
op->getLoc(), RankedTensorType::get(resultType.getShape(), elementType),
|
||||
self);
|
||||
}
|
||||
|
||||
auto oneHalf =
|
||||
tosa::getConstTensor<float>(rewriter, op, 0.5, {}, elementType).value();
|
||||
|
||||
rewriter.replaceOpWithNewOp<tosa::PowOp>(op, resultType, self, oneHalf);
|
||||
return success();
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
// -----------------------------------------------------------------------------
|
||||
|
@ -4792,6 +4821,7 @@ public:
|
|||
INSERT_ATENOP_PATTERN(AtenConstantPadNdOp);
|
||||
INSERT_ATENOP_PATTERN(AtenRemainderScalarOp);
|
||||
INSERT_ATENOP_PATTERN(AtenCatOp);
|
||||
INSERT_ATENOP_PATTERN(AtenSqrtOp);
|
||||
#undef INSERT_ATENOP_PATTERN
|
||||
|
||||
#define INSERT_CLONE_ATENOP_PATTERN(AtenOp) \
|
||||
|
|
Loading…
Reference in New Issue