mirror of https://github.com/llvm/torch-mlir
[Stablehlo] simplify promoteType (#3525)
only provide `outElementType` when promoteTypepull/3530/head
parent
dcb48dd46c
commit
5bee9aac63
|
@ -50,7 +50,7 @@ Value scalarToStablehloTensor(ConversionPatternRewriter &rewriter,
|
|||
Operation *op, Value scalarValue, Type dtype);
|
||||
|
||||
Value promoteType(PatternRewriter &rewriter, Location loc, Value input,
|
||||
TensorType outType);
|
||||
Type outElementType);
|
||||
|
||||
Value promoteAndBroadcast(ConversionPatternRewriter &rewriter, Value input,
|
||||
TensorType outType);
|
||||
|
|
|
@ -148,7 +148,8 @@ public:
|
|||
auto outType = cast<TensorType>(
|
||||
OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
|
||||
op.getType()));
|
||||
self = hlo::promoteType(rewriter, op.getLoc(), self, outType);
|
||||
self =
|
||||
hlo::promoteType(rewriter, op.getLoc(), self, outType.getElementType());
|
||||
rewriter.replaceOpWithNewOp<StablehloOpT>(op, outType, self);
|
||||
return success();
|
||||
}
|
||||
|
@ -207,7 +208,8 @@ public:
|
|||
op.getType()));
|
||||
|
||||
if (isa<mlir::FloatType>(resultTy.getElementType())) {
|
||||
Value src = hlo::promoteType(rewriter, op.getLoc(), self, resultTy);
|
||||
Value src = hlo::promoteType(rewriter, op.getLoc(), self,
|
||||
resultTy.getElementType());
|
||||
rewriter.replaceOpWithNewOp<StablehloOpT>(op, resultTy, src);
|
||||
return success();
|
||||
} else {
|
||||
|
@ -334,8 +336,8 @@ public:
|
|||
OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
|
||||
op.getType()));
|
||||
|
||||
lhs = hlo::promoteType(rewriter, op.getLoc(), lhs, outTy);
|
||||
rhs = hlo::promoteType(rewriter, op.getLoc(), rhs, outTy);
|
||||
lhs = hlo::promoteType(rewriter, op.getLoc(), lhs, outTy.getElementType());
|
||||
rhs = hlo::promoteType(rewriter, op.getLoc(), rhs, outTy.getElementType());
|
||||
|
||||
rewriter.replaceOpWithNewOp<ChloOpT>(op, outTy, lhs, rhs,
|
||||
/*broadcast_attr*/ nullptr);
|
||||
|
@ -381,8 +383,8 @@ public:
|
|||
}
|
||||
}
|
||||
|
||||
lhs = hlo::promoteType(rewriter, op.getLoc(), lhs, outType);
|
||||
rhs = hlo::promoteType(rewriter, op.getLoc(), rhs, outType);
|
||||
lhs = hlo::promoteType(rewriter, op.getLoc(), lhs, outElemTy);
|
||||
rhs = hlo::promoteType(rewriter, op.getLoc(), rhs, outElemTy);
|
||||
|
||||
if (!skipMultiplyAlpha(op.getAlpha())) {
|
||||
Value alpha = hlo::scalarToStablehloTensor(rewriter, op,
|
||||
|
@ -437,8 +439,8 @@ public:
|
|||
}
|
||||
}
|
||||
DenseI64ArrayAttr bcastDimensions;
|
||||
lhs = hlo::promoteType(rewriter, op.getLoc(), lhs, outType);
|
||||
rhs = hlo::promoteType(rewriter, op.getLoc(), rhs, outType);
|
||||
lhs = hlo::promoteType(rewriter, op.getLoc(), lhs, outElemTy);
|
||||
rhs = hlo::promoteType(rewriter, op.getLoc(), rhs, outElemTy);
|
||||
auto loc = op.getLoc();
|
||||
Value result =
|
||||
rewriter.create<ChloOpT>(loc, outType, lhs, rhs, bcastDimensions);
|
||||
|
@ -476,16 +478,17 @@ public:
|
|||
if (isa<mlir::FloatType>(outElemTy))
|
||||
result = rewriter.create<stablehlo::FloorOp>(loc, result).getResult();
|
||||
else if (!outElemTy.isUnsignedInteger()) {
|
||||
TensorType defaultIntToFloatType =
|
||||
outType.cloneWith(outType.getShape(), rewriter.getF64Type());
|
||||
Type defaultIntToFloatType = rewriter.getF64Type();
|
||||
lhs =
|
||||
hlo::promoteType(rewriter, op.getLoc(), lhs, defaultIntToFloatType);
|
||||
rhs =
|
||||
hlo::promoteType(rewriter, op.getLoc(), rhs, defaultIntToFloatType);
|
||||
result = rewriter.create<ChloOpT>(loc, defaultIntToFloatType, lhs, rhs,
|
||||
bcastDimensions);
|
||||
result = rewriter.create<ChloOpT>(
|
||||
loc, outType.cloneWith(outType.getShape(), defaultIntToFloatType),
|
||||
lhs, rhs, bcastDimensions);
|
||||
result = rewriter.create<stablehlo::FloorOp>(loc, result).getResult();
|
||||
result = hlo::promoteType(rewriter, op.getLoc(), result, outType);
|
||||
result = hlo::promoteType(rewriter, op.getLoc(), result,
|
||||
outType.getElementType());
|
||||
}
|
||||
}
|
||||
rewriter.replaceOp(op, result);
|
||||
|
@ -517,7 +520,8 @@ public:
|
|||
rhs = hlo::scalarToStablehloTensor(rewriter, op, adaptor.getOther(),
|
||||
rhs.getType());
|
||||
// use lhs's element type as compute type
|
||||
rhs = hlo::promoteType(rewriter, op.getLoc(), rhs, lhsTy);
|
||||
rhs =
|
||||
hlo::promoteType(rewriter, op.getLoc(), rhs, lhsTy.getElementType());
|
||||
rhsTy = dyn_cast<RankedTensorType>(rhs.getType());
|
||||
}
|
||||
|
||||
|
@ -533,16 +537,16 @@ public:
|
|||
}
|
||||
|
||||
if (isa<mlir::IntegerType>(lhsElemTy) && isa<mlir::FloatType>(rhsElemTy)) {
|
||||
lhs = hlo::promoteType(rewriter, op.getLoc(), lhs, rhsTy);
|
||||
lhs = hlo::promoteType(rewriter, op.getLoc(), lhs, rhsElemTy);
|
||||
} else if (isa<mlir::FloatType>(lhsElemTy) &&
|
||||
isa<mlir::IntegerType>(rhsElemTy)) {
|
||||
rhs = hlo::promoteType(rewriter, op.getLoc(), rhs, lhsTy);
|
||||
rhs = hlo::promoteType(rewriter, op.getLoc(), rhs, lhsElemTy);
|
||||
} else {
|
||||
if (lhsElemTy.getIntOrFloatBitWidth() >
|
||||
rhsElemTy.getIntOrFloatBitWidth()) {
|
||||
rhs = hlo::promoteType(rewriter, op.getLoc(), rhs, lhsTy);
|
||||
rhs = hlo::promoteType(rewriter, op.getLoc(), rhs, lhsElemTy);
|
||||
} else {
|
||||
lhs = hlo::promoteType(rewriter, op.getLoc(), lhs, rhsTy);
|
||||
lhs = hlo::promoteType(rewriter, op.getLoc(), lhs, rhsElemTy);
|
||||
}
|
||||
}
|
||||
lhsElemTy = dyn_cast<RankedTensorType>(lhs.getType()).getElementType();
|
||||
|
@ -622,11 +626,11 @@ public:
|
|||
OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
|
||||
op.getType()));
|
||||
Type outElemTy = outType.getElementType();
|
||||
lhs = hlo::promoteType(rewriter, op.getLoc(), lhs, outType);
|
||||
lhs = hlo::promoteType(rewriter, op.getLoc(), lhs, outElemTy);
|
||||
if (!rhsTy) {
|
||||
rhs = hlo::scalarToStablehloTensor(rewriter, op, rhs, outElemTy);
|
||||
}
|
||||
rhs = hlo::promoteType(rewriter, op.getLoc(), rhs, outType);
|
||||
rhs = hlo::promoteType(rewriter, op.getLoc(), rhs, outElemTy);
|
||||
|
||||
DenseI64ArrayAttr bcastDimensions;
|
||||
rewriter.replaceOpWithNewOp<ChloOpT>(op, outType, lhs, rhs,
|
||||
|
@ -736,8 +740,10 @@ LogicalResult ConvertAtenOp<AtenWhereSelfOp>::matchAndRewrite(
|
|||
auto outType =
|
||||
cast<RankedTensorType>(getTypeConverter()->convertType(op.getType()));
|
||||
// promote self and other types
|
||||
self = hlo::promoteType(rewriter, op.getLoc(), self, outType);
|
||||
other = hlo::promoteType(rewriter, op.getLoc(), other, outType);
|
||||
self =
|
||||
hlo::promoteType(rewriter, op.getLoc(), self, outType.getElementType());
|
||||
other =
|
||||
hlo::promoteType(rewriter, op.getLoc(), other, outType.getElementType());
|
||||
|
||||
if (failed(broadcastRanks(rewriter, op, self, cond)))
|
||||
return op.emitError("failed broadcast self and condition ranks");
|
||||
|
@ -940,8 +946,8 @@ LogicalResult ConvertAtenOp<AtenPowTensorScalarOp>::matchAndRewrite(
|
|||
rhs = hlo::scalarToStablehloTensor(rewriter, op, rhs, outElemTy);
|
||||
}
|
||||
DenseI64ArrayAttr bcastDimensions;
|
||||
lhs = hlo::promoteType(rewriter, op.getLoc(), lhs, outType);
|
||||
rhs = hlo::promoteType(rewriter, op.getLoc(), rhs, outType);
|
||||
lhs = hlo::promoteType(rewriter, op.getLoc(), lhs, outElemTy);
|
||||
rhs = hlo::promoteType(rewriter, op.getLoc(), rhs, outElemTy);
|
||||
auto loc = op.getLoc();
|
||||
Value result = rewriter.create<chlo::BroadcastPowOp>(loc, outType, lhs, rhs,
|
||||
bcastDimensions);
|
||||
|
@ -977,8 +983,8 @@ LogicalResult ConvertAtenOp<AtenPowScalarOp>::matchAndRewrite(
|
|||
lhs = hlo::scalarToStablehloTensor(rewriter, op, lhs, outElemTy);
|
||||
}
|
||||
DenseI64ArrayAttr bcastDimensions;
|
||||
lhs = hlo::promoteType(rewriter, op.getLoc(), lhs, outType);
|
||||
rhs = hlo::promoteType(rewriter, op.getLoc(), rhs, outType);
|
||||
lhs = hlo::promoteType(rewriter, op.getLoc(), lhs, outElemTy);
|
||||
rhs = hlo::promoteType(rewriter, op.getLoc(), rhs, outElemTy);
|
||||
auto loc = op.getLoc();
|
||||
Value result = rewriter.create<chlo::BroadcastPowOp>(loc, outType, lhs, rhs,
|
||||
bcastDimensions);
|
||||
|
@ -1121,7 +1127,8 @@ LogicalResult ConvertAtenOp<AtenLog2Op>::matchAndRewrite(
|
|||
return op.emitError("only ranked tensor type is supported.");
|
||||
}
|
||||
auto outTy = cast<TensorType>(getTypeConverter()->convertType(op.getType()));
|
||||
input = hlo::promoteType(rewriter, op.getLoc(), input, outTy);
|
||||
input =
|
||||
hlo::promoteType(rewriter, op.getLoc(), input, outTy.getElementType());
|
||||
|
||||
auto two = hlo::getConstantLike(rewriter, op.getLoc(), 2.0, input);
|
||||
auto log2Op = rewriter.create<stablehlo::LogOp>(op.getLoc(), two);
|
||||
|
@ -1143,7 +1150,8 @@ LogicalResult ConvertAtenOp<AtenLog10Op>::matchAndRewrite(
|
|||
}
|
||||
|
||||
auto outTy = cast<TensorType>(getTypeConverter()->convertType(op.getType()));
|
||||
input = hlo::promoteType(rewriter, op.getLoc(), input, outTy);
|
||||
input =
|
||||
hlo::promoteType(rewriter, op.getLoc(), input, outTy.getElementType());
|
||||
|
||||
auto ten = hlo::getConstantLike(rewriter, op.getLoc(), 10.0, input);
|
||||
auto log10Op = rewriter.create<stablehlo::LogOp>(op.getLoc(), ten);
|
||||
|
@ -1266,42 +1274,44 @@ LogicalResult ConvertAtenOp<AtenBatchNormOp>::matchAndRewrite(
|
|||
"non-bool cudnn_enabled unsupported");
|
||||
}
|
||||
if (training) {
|
||||
Type outputTy = getTypeConverter()->convertType(op.getType());
|
||||
Type batchMeanOrVarTy =
|
||||
RankedTensorType::get(weightTy.getShape(), inputTy.getElementType());
|
||||
TensorType outputTy =
|
||||
cast<TensorType>(getTypeConverter()->convertType(op.getType()));
|
||||
|
||||
Value output;
|
||||
// supported mixed types, like input type is fp16 and weight type is fp32.
|
||||
if (inputTy.getElementType() != weightTy.getElementType()) {
|
||||
RankedTensorType convertedType = inputTy;
|
||||
if (cast<FloatType>(weightTy.getElementType()).getWidth() >
|
||||
cast<FloatType>(inputTy.getElementType()).getWidth()) {
|
||||
convertedType = RankedTensorType::get(inputTy.getShape(),
|
||||
weightTy.getElementType());
|
||||
Type computeType = inputTy.getElementType();
|
||||
if (weightTy.getElementType().getIntOrFloatBitWidth() >
|
||||
inputTy.getElementType().getIntOrFloatBitWidth()) {
|
||||
computeType = weightTy.getElementType();
|
||||
}
|
||||
input = hlo::promoteType(rewriter, op.getLoc(), input, convertedType);
|
||||
weight = hlo::promoteType(rewriter, op.getLoc(), weight, convertedType);
|
||||
bias = hlo::promoteType(rewriter, op.getLoc(), bias, convertedType);
|
||||
input = hlo::promoteType(rewriter, op.getLoc(), input, computeType);
|
||||
weight = hlo::promoteType(rewriter, op.getLoc(), weight, computeType);
|
||||
bias = hlo::promoteType(rewriter, op.getLoc(), bias, computeType);
|
||||
auto batchNormTrainingResult =
|
||||
rewriter.create<stablehlo::BatchNormTrainingOp>(
|
||||
op.getLoc(), outputTy, batchMeanOrVarTy, batchMeanOrVarTy, input,
|
||||
op.getLoc(),
|
||||
RankedTensorType::get(inputTy.getShape(), computeType),
|
||||
RankedTensorType::get(weightTy.getShape(), computeType),
|
||||
RankedTensorType::get(weightTy.getShape(), computeType), input,
|
||||
weight, bias, rewriter.getF32FloatAttr(eps),
|
||||
rewriter.getI64IntegerAttr(feature_index));
|
||||
output = hlo::promoteType(rewriter, op.getLoc(),
|
||||
batchNormTrainingResult.getResult(0),
|
||||
cast<TensorType>(outputTy));
|
||||
outputTy.getElementType());
|
||||
} else {
|
||||
auto batchNormTrainingResult =
|
||||
rewriter.create<stablehlo::BatchNormTrainingOp>(
|
||||
op.getLoc(), outputTy, batchMeanOrVarTy, batchMeanOrVarTy, input,
|
||||
weight, bias, rewriter.getF32FloatAttr(eps),
|
||||
op.getLoc(), outputTy, weightTy, weightTy, input, weight, bias,
|
||||
rewriter.getF32FloatAttr(eps),
|
||||
rewriter.getI64IntegerAttr(feature_index));
|
||||
output = batchNormTrainingResult.getResult(0);
|
||||
}
|
||||
rewriter.replaceOp(op, output);
|
||||
return success();
|
||||
} else {
|
||||
Type outputTy = getTypeConverter()->convertType(op.getType());
|
||||
TensorType outputTy =
|
||||
cast<TensorType>(getTypeConverter()->convertType(op.getType()));
|
||||
SmallVector<int64_t, 4> castShape{inputTy.getShape().begin(),
|
||||
inputTy.getShape().end()};
|
||||
castShape[1] = weightTy.getShape()[0];
|
||||
|
@ -1314,26 +1324,25 @@ LogicalResult ConvertAtenOp<AtenBatchNormOp>::matchAndRewrite(
|
|||
Value output;
|
||||
// supported mixed types, like input type is fp16 and weight type is fp32.
|
||||
if (inputTy.getElementType() != weightTy.getElementType()) {
|
||||
RankedTensorType convertedType = inputTy;
|
||||
if (cast<FloatType>(weightTy.getElementType()).getWidth() >
|
||||
cast<FloatType>(inputTy.getElementType()).getWidth()) {
|
||||
convertedType = RankedTensorType::get(inputTy.getShape(),
|
||||
weightTy.getElementType());
|
||||
Type computeType = inputTy.getElementType();
|
||||
if (weightTy.getElementType().getIntOrFloatBitWidth() >
|
||||
inputTy.getElementType().getIntOrFloatBitWidth()) {
|
||||
computeType = weightTy.getElementType();
|
||||
}
|
||||
input =
|
||||
hlo::promoteType(rewriter, op.getLoc(), inputCasted, convertedType);
|
||||
weight = hlo::promoteType(rewriter, op.getLoc(), weight, convertedType);
|
||||
bias = hlo::promoteType(rewriter, op.getLoc(), bias, convertedType);
|
||||
input = hlo::promoteType(rewriter, op.getLoc(), inputCasted, computeType);
|
||||
weight = hlo::promoteType(rewriter, op.getLoc(), weight, computeType);
|
||||
bias = hlo::promoteType(rewriter, op.getLoc(), bias, computeType);
|
||||
runningMean =
|
||||
hlo::promoteType(rewriter, op.getLoc(), runningMean, convertedType);
|
||||
hlo::promoteType(rewriter, op.getLoc(), runningMean, computeType);
|
||||
runningVar =
|
||||
hlo::promoteType(rewriter, op.getLoc(), runningVar, convertedType);
|
||||
hlo::promoteType(rewriter, op.getLoc(), runningVar, computeType);
|
||||
Value bnResult = rewriter.create<stablehlo::BatchNormInferenceOp>(
|
||||
op.getLoc(), convertedType, input, weight, bias, runningMean,
|
||||
runningVar, rewriter.getF32FloatAttr(eps),
|
||||
op.getLoc(), RankedTensorType::get(inputTy.getShape(), computeType),
|
||||
input, weight, bias, runningMean, runningVar,
|
||||
rewriter.getF32FloatAttr(eps),
|
||||
rewriter.getI64IntegerAttr(feature_index));
|
||||
output = hlo::promoteType(rewriter, op.getLoc(), bnResult,
|
||||
cast<TensorType>(outputTy));
|
||||
outputTy.getElementType());
|
||||
} else {
|
||||
output = rewriter.create<stablehlo::BatchNormInferenceOp>(
|
||||
op.getLoc(), inputCasted.getType(), inputCasted, weight, bias,
|
||||
|
@ -1515,7 +1524,7 @@ LogicalResult ConvertAtenOp<AtenCatOp>::matchAndRewrite(
|
|||
|
||||
// Promote type
|
||||
for (auto &v : builtinTensors) {
|
||||
v = hlo::promoteType(rewriter, op->getLoc(), v, outType);
|
||||
v = hlo::promoteType(rewriter, op->getLoc(), v, outType.getElementType());
|
||||
}
|
||||
|
||||
rewriter.replaceOpWithNewOp<stablehlo::ConcatenateOp>(
|
||||
|
@ -1787,8 +1796,8 @@ LogicalResult ConvertAtenOp<AtenPowTensorTensorOp>::matchAndRewrite(
|
|||
auto outTy =
|
||||
cast<TensorType>(this->getTypeConverter()->convertType(op.getType()));
|
||||
|
||||
lhs = hlo::promoteType(rewriter, op.getLoc(), lhs, outTy);
|
||||
rhs = hlo::promoteType(rewriter, op.getLoc(), rhs, outTy);
|
||||
lhs = hlo::promoteType(rewriter, op.getLoc(), lhs, outTy.getElementType());
|
||||
rhs = hlo::promoteType(rewriter, op.getLoc(), rhs, outTy.getElementType());
|
||||
|
||||
rewriter.replaceOpWithNewOp<chlo::BroadcastPowOp>(op, outTy, lhs, rhs,
|
||||
/*broadcast_attr*/ nullptr);
|
||||
|
@ -1961,8 +1970,10 @@ LogicalResult ConvertAtenOp<AtenRemainderTensorOp>::matchAndRewrite(
|
|||
|
||||
auto resultType =
|
||||
cast<RankedTensorType>(getTypeConverter()->convertType(op.getType()));
|
||||
lhs = hlo::promoteType(rewriter, op->getLoc(), lhs, resultType);
|
||||
rhs = hlo::promoteType(rewriter, op->getLoc(), rhs, resultType);
|
||||
lhs = hlo::promoteType(rewriter, op->getLoc(), lhs,
|
||||
resultType.getElementType());
|
||||
rhs = hlo::promoteType(rewriter, op->getLoc(), rhs,
|
||||
resultType.getElementType());
|
||||
rewriter.replaceOpWithNewOp<stablehlo::RemOp>(op, lhs, rhs);
|
||||
return success();
|
||||
}
|
||||
|
@ -1979,8 +1990,10 @@ LogicalResult ConvertAtenOp<AtenFmodTensorOp>::matchAndRewrite(
|
|||
|
||||
auto resultType =
|
||||
cast<RankedTensorType>(getTypeConverter()->convertType(op.getType()));
|
||||
lhs = hlo::promoteType(rewriter, op->getLoc(), lhs, resultType);
|
||||
rhs = hlo::promoteType(rewriter, op->getLoc(), rhs, resultType);
|
||||
lhs = hlo::promoteType(rewriter, op->getLoc(), lhs,
|
||||
resultType.getElementType());
|
||||
rhs = hlo::promoteType(rewriter, op->getLoc(), rhs,
|
||||
resultType.getElementType());
|
||||
|
||||
stablehlo::MulOp mul;
|
||||
auto div = rewriter.create<stablehlo::DivOp>(loc, lhs, rhs);
|
||||
|
|
|
@ -835,7 +835,8 @@ public:
|
|||
llvm::to_vector<4>(llvm::seq<int64_t>(-nSpatialDims, 0));
|
||||
|
||||
bias = *hlo::unsqueezeTensor(rewriter, op, bias, inputUnsqzDims);
|
||||
bias = hlo::promoteType(rewriter, op.getLoc(), bias, outTy);
|
||||
bias =
|
||||
hlo::promoteType(rewriter, op.getLoc(), bias, outTy.getElementType());
|
||||
|
||||
DenseI64ArrayAttr bcastDimensions;
|
||||
rewriter.replaceOpWithNewOp<chlo::BroadcastAddOp>(
|
||||
|
|
|
@ -522,7 +522,8 @@ public:
|
|||
} else {
|
||||
assert(false && "Unsupported pooling dimension");
|
||||
}
|
||||
divisor = hlo::promoteType(rewriter, op.getLoc(), divisor, outTy);
|
||||
divisor = hlo::promoteType(rewriter, op.getLoc(), divisor,
|
||||
outTy.getElementType());
|
||||
DenseI64ArrayAttr bcastDimensions;
|
||||
rewriter.replaceOpWithNewOp<mlir::chlo::BroadcastDivOp>(
|
||||
op, outTy, reduceWindowSum.getResult(0), divisor, bcastDimensions);
|
||||
|
@ -532,8 +533,8 @@ public:
|
|||
// Use another mhlo.ReduceWindowOp to get the divisor
|
||||
Value windowSizeConst =
|
||||
hlo::getConstTensor<float>(rewriter, op, {1.0}, {}).value();
|
||||
windowSizeConst =
|
||||
hlo::promoteType(rewriter, op.getLoc(), windowSizeConst, outTy);
|
||||
windowSizeConst = hlo::promoteType(rewriter, op.getLoc(), windowSizeConst,
|
||||
outTy.getElementType());
|
||||
auto inputShapeVec = *hlo::getDimIndexOfTensor(rewriter, op, input);
|
||||
auto inputShapeTensor = rewriter.create<mlir::tensor::FromElementsOp>(
|
||||
op->getLoc(), inputShapeVec);
|
||||
|
@ -583,7 +584,8 @@ LogicalResult ConvertAtenOp<AtenCumsumOp>::matchAndRewrite(
|
|||
auto inputTy = cast<RankedTensorType>(input.getType());
|
||||
auto outTy =
|
||||
cast<RankedTensorType>(getTypeConverter()->convertType(op.getType()));
|
||||
input = hlo::promoteType(rewriter, op.getLoc(), input, outTy);
|
||||
input =
|
||||
hlo::promoteType(rewriter, op.getLoc(), input, outTy.getElementType());
|
||||
inputTy = cast<RankedTensorType>(input.getType());
|
||||
auto inputElemTy = inputTy.getElementType();
|
||||
auto inputRank = inputTy.getRank();
|
||||
|
|
|
@ -170,13 +170,10 @@ Value scalarToStablehloTensor(ConversionPatternRewriter &rewriter,
|
|||
}
|
||||
|
||||
Value promoteType(PatternRewriter &rewriter, Location loc, Value input,
|
||||
TensorType outType) {
|
||||
TensorType in_type = cast<TensorType>(input.getType());
|
||||
|
||||
if (in_type.getElementType() != outType.getElementType()) {
|
||||
TensorType promotedType =
|
||||
in_type.cloneWith(in_type.getShape(), outType.getElementType());
|
||||
return rewriter.create<stablehlo::ConvertOp>(loc, promotedType, input);
|
||||
Type outElementType) {
|
||||
TensorType inType = cast<TensorType>(input.getType());
|
||||
if (inType.getElementType() != outElementType) {
|
||||
return rewriter.create<stablehlo::ConvertOp>(loc, input, outElementType);
|
||||
}
|
||||
return input;
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue