diff --git a/lib/Conversion/TorchToStablehlo/Pooling.cpp b/lib/Conversion/TorchToStablehlo/Pooling.cpp index 7c28a2fd3..e90f231c7 100644 --- a/lib/Conversion/TorchToStablehlo/Pooling.cpp +++ b/lib/Conversion/TorchToStablehlo/Pooling.cpp @@ -569,11 +569,13 @@ LogicalResult ConvertAtenOp::matchAndRewrite( ConversionPatternRewriter &rewriter) const { Value input = adaptor.getSelf(); auto inputTy = input.getType().cast(); + auto outTy = + getTypeConverter()->convertType(op.getType()).cast(); + input = hlo::promoteType(rewriter, op.getLoc(), input, outTy); + inputTy = input.getType().cast(); auto inputElemTy = inputTy.getElementType(); auto inputRank = inputTy.getRank(); auto inputShape = inputTy.getShape(); - auto outTy = - getTypeConverter()->convertType(op.getType()).cast(); int64_t dim; if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim))) {