mirror of https://github.com/llvm/torch-mlir
[stablehlo] add dtype conversion when converting AtenScalarImplicitOp (#2439)
parent
3841fe3035
commit
c93c6970e8
|
@ -898,7 +898,16 @@ template <>
|
|||
LogicalResult ConvertAtenOp<AtenScalarImplicitOp>::matchAndRewrite(
|
||||
AtenScalarImplicitOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
rewriter.replaceOpWithNewOp<tensor::ExtractOp>(op, adaptor.getA());
|
||||
Location loc = op.getLoc();
|
||||
Type inputDtype =
|
||||
op.getA().getType().template cast<BaseTensorType>().getDtype();
|
||||
Type resultType =
|
||||
this->getTypeConverter()->convertType(op->getResult(0).getType());
|
||||
auto result =
|
||||
rewriter.create<tensor::ExtractOp>(loc, adaptor.getA());
|
||||
|
||||
rewriter.replaceOp(
|
||||
op, convertScalarToDtype(rewriter, loc, result, resultType, inputDtype));
|
||||
return success();
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue