[stablehlo] add dtype conversion when converting AtenScalarImplicitOp (#2439)

pull/2441/head
Jiawei Wu 2023-09-06 01:57:15 +08:00 committed by GitHub
parent 3841fe3035
commit c93c6970e8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 10 additions and 1 deletions

View File

@ -898,7 +898,16 @@ template <>
LogicalResult ConvertAtenOp<AtenScalarImplicitOp>::matchAndRewrite(
AtenScalarImplicitOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
rewriter.replaceOpWithNewOp<tensor::ExtractOp>(op, adaptor.getA());
Location loc = op.getLoc();
Type inputDtype =
op.getA().getType().template cast<BaseTensorType>().getDtype();
Type resultType =
this->getTypeConverter()->convertType(op->getResult(0).getType());
auto result =
rewriter.create<tensor::ExtractOp>(loc, adaptor.getA());
rewriter.replaceOp(
op, convertScalarToDtype(rewriter, loc, result, resultType, inputDtype));
return success();
}