mirror of https://github.com/llvm/torch-mlir
parent
7f6b72aec8
commit
48383554da
|
@ -1108,6 +1108,8 @@ TOSA_PASS_SET = {
|
||||||
"TensorsConcatStaticModule_basic",
|
"TensorsConcatStaticModule_basic",
|
||||||
"TensorsConcatNegativeDimStaticModule_basic",
|
"TensorsConcatNegativeDimStaticModule_basic",
|
||||||
"AtenComplex64Module_basic",
|
"AtenComplex64Module_basic",
|
||||||
|
"ElementwiseSqrtIntModule_basic",
|
||||||
|
"ElementwiseSqrtModule_basic",
|
||||||
"SplitTensorGetItem_Module_basic",
|
"SplitTensorGetItem_Module_basic",
|
||||||
"SplitTensorListUnpackModule_basic",
|
"SplitTensorListUnpackModule_basic",
|
||||||
"ChunkListUnpack_Module_basic",
|
"ChunkListUnpack_Module_basic",
|
||||||
|
|
|
@ -4563,6 +4563,35 @@ LogicalResult ConvertAtenOp<AtenCatOp>::matchAndRewrite(
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
LogicalResult ConvertAtenOp<AtenSqrtOp>::matchAndRewrite(
|
||||||
|
AtenSqrtOp op, OpAdaptor adaptor,
|
||||||
|
ConversionPatternRewriter &rewriter) const {
|
||||||
|
|
||||||
|
// Converts AtenSqrtOp into pow(x, 0.5)
|
||||||
|
auto self = adaptor.getSelf();
|
||||||
|
auto selfTy = self.getType().dyn_cast<TensorType>();
|
||||||
|
if (!selfTy)
|
||||||
|
return rewriter.notifyMatchFailure(op,
|
||||||
|
"Only Tensor types supported in TOSA");
|
||||||
|
|
||||||
|
auto resultType = typeConverter->convertType(op.getType())
|
||||||
|
.template cast<RankedTensorType>();
|
||||||
|
auto elementType = resultType.getElementType();
|
||||||
|
|
||||||
|
if (isa<mlir::IntegerType>(selfTy.getElementType())) {
|
||||||
|
self = rewriter.createOrFold<tosa::CastOp>(
|
||||||
|
op->getLoc(), RankedTensorType::get(resultType.getShape(), elementType),
|
||||||
|
self);
|
||||||
|
}
|
||||||
|
|
||||||
|
auto oneHalf =
|
||||||
|
tosa::getConstTensor<float>(rewriter, op, 0.5, {}, elementType).value();
|
||||||
|
|
||||||
|
rewriter.replaceOpWithNewOp<tosa::PowOp>(op, resultType, self, oneHalf);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
// -----------------------------------------------------------------------------
|
// -----------------------------------------------------------------------------
|
||||||
|
@ -4792,6 +4821,7 @@ public:
|
||||||
INSERT_ATENOP_PATTERN(AtenConstantPadNdOp);
|
INSERT_ATENOP_PATTERN(AtenConstantPadNdOp);
|
||||||
INSERT_ATENOP_PATTERN(AtenRemainderScalarOp);
|
INSERT_ATENOP_PATTERN(AtenRemainderScalarOp);
|
||||||
INSERT_ATENOP_PATTERN(AtenCatOp);
|
INSERT_ATENOP_PATTERN(AtenCatOp);
|
||||||
|
INSERT_ATENOP_PATTERN(AtenSqrtOp);
|
||||||
#undef INSERT_ATENOP_PATTERN
|
#undef INSERT_ATENOP_PATTERN
|
||||||
|
|
||||||
#define INSERT_CLONE_ATENOP_PATTERN(AtenOp) \
|
#define INSERT_CLONE_ATENOP_PATTERN(AtenOp) \
|
||||||
|
|
Loading…
Reference in New Issue