mirror of https://github.com/llvm/torch-mlir
[stablehlo] add dtype conversion when converting AtenScalarImplicitOp (#2439)
parent
3841fe3035
commit
c93c6970e8
|
@ -898,7 +898,16 @@ template <>
|
||||||
LogicalResult ConvertAtenOp<AtenScalarImplicitOp>::matchAndRewrite(
|
LogicalResult ConvertAtenOp<AtenScalarImplicitOp>::matchAndRewrite(
|
||||||
AtenScalarImplicitOp op, OpAdaptor adaptor,
|
AtenScalarImplicitOp op, OpAdaptor adaptor,
|
||||||
ConversionPatternRewriter &rewriter) const {
|
ConversionPatternRewriter &rewriter) const {
|
||||||
rewriter.replaceOpWithNewOp<tensor::ExtractOp>(op, adaptor.getA());
|
Location loc = op.getLoc();
|
||||||
|
Type inputDtype =
|
||||||
|
op.getA().getType().template cast<BaseTensorType>().getDtype();
|
||||||
|
Type resultType =
|
||||||
|
this->getTypeConverter()->convertType(op->getResult(0).getType());
|
||||||
|
auto result =
|
||||||
|
rewriter.create<tensor::ExtractOp>(loc, adaptor.getA());
|
||||||
|
|
||||||
|
rewriter.replaceOp(
|
||||||
|
op, convertScalarToDtype(rewriter, loc, result, resultType, inputDtype));
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue